Hello everyone.
I’m facing a little problem when pyg creates batch on my custom bipartite data class. (Occurs when batchsize >= 2)
my version:
torch 1.9.0
torch_geometric 2.0.2
pytorch-cluster 1.5.9 py39_torch_1.9.0_cu102 pyg
pytorch-scatter 2.0.9 py39_torch_1.9.0_cu102 pyg
pytorch-sparse 0.6.12 py39_torch_1.9.0_cu102 pyg
pytorch-spline-conv 1.2.1 py39_torch_1.9.0_cu102 pyg
Traceback (most recent call last):
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
cli.main()
File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
run()
File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
runpy.run_path(target, run_name="__main__")
File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
return _run_module_code(code, init_globals, run_name,
File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
exec(code, run_globals)
File "/data1/Jay/reproduce/xytest_train.py", line 157, in <module>
for batch_No, batch in enumerate(train_dataloader):
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/loader/dataloader.py", line 34, in __call__
return [self(s) for s in zip(*batch)]
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/loader/dataloader.py", line 34, in <listcomp>
return [self(s) for s in zip(*batch)]
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/loader/dataloader.py", line 19, in __call__
return Batch.from_data_list(batch, self.follow_batch,
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 69, in from_data_list
batch, slice_dict, inc_dict = collate(
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/data/collate.py", line 33, in collate
out = cls(_base_cls=data_list[0].__class__) # Dynamic inheritance.
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 41, in __call__
return super(DynamicInheritance, new_cls).__call__(*args, **kwargs)
File "/data1/Jay/reproduce/utility.py", line 324, in __init__
self.var_feats = torch.tensor(var_feats.clone(), dtype=torch.float32)
AttributeError: 'NoneType' object has no attribute 'clone'
The source code of my class is
from torch_geometric.data import Data
class BipartiteData_Lite(Data):
"""custom bipartite data"""
def __init__(
self,
cstr_feats,
edge_indices,
edge_values,
var_feats,
):
super(BipartiteData_Lite, self).__init__()
self.var_feats = torch.tensor(var_feats.clone(), dtype=torch.float32)
self.cstr_feats = torch.tensor(cstr_feats, dtype=torch.float32)
self.edge_index = torch.tensor(edge_indices, dtype=torch.long)
self.edge_attr = torch.tensor(edge_values, dtype=torch.float32)
self.num_nodes = self.var_feats.shape[1] + self.cstr_feats.shape[1]
def __inc__(self, key, value, *args, **kwargs):
if key == "edge_index":
return torch.tensor([[self.cstr_feats.size(0)], [self.var_feats.size(0)]])
else:
return super().__inc__(key, value, *args, **kwargs)
class BipartiteDataset_Lite(Dataset):
...
part of xytest_train.py (training script) is as follows
training_set = BipartiteDataset_Lite(training_set_path)
train_dataloader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
for batch_No, batch in enumerate(train_dataloader):
...
It seems that, when torch_geometric tries to construct a mini-batch, after making <batchsize> instances, it calls BipartiteData_Lite.__init__() again, sending NoneTypeObject to BipartiteData_Lite.__init__()
I think this issue is from here.
File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 41, in __call__
return super(DynamicInheritance, new_cls).__call__(*args, **kwargs)
Best regards.