From 1bafb719996fd5536ea648e853f9a00db31bce9f Mon Sep 17 00:00:00 2001 From: drprojects Date: Fri, 19 Jul 2024 14:39:50 +0200 Subject: [PATCH] fix for issues #139 #141 --- src/data/data.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/data/data.py b/src/data/data.py index 65e8a8a8..d0ab1c23 100644 --- a/src/data/data.py +++ b/src/data/data.py @@ -926,11 +926,13 @@ def from_data_list(cls, data_list, follow_batch=None, exclude_keys=None): batch = super().from_data_list( data_list, follow_batch=follow_batch, exclude_keys=exclude_keys) - # Dirty trick: manually convert 'sub' to a proper ClusterBatch - # and 'obj' to a proper InstanceBatch. + # PyG does not know how to batch Cluster and InstanceData + # objects. So the 'sub' and 'obj' attributes will contain lists + # of such objects. We now need to manually convert these to + # proper ClusterBatch and InstanceBatch. # Note we will need to do the same in `get_example` to avoid # breaking PyG Batch mechanisms - if batch.is_super and isinstance(batch.sub, Cluster): + if batch.is_super: batch.sub = ClusterBatch.from_list(batch.sub) if batch.obj is not None: batch.obj = InstanceBatch.from_list(batch.obj)