'NoneType' error when PyG creates a batch

Hello everyone.

I’m facing a small problem when PyG creates a batch from my custom bipartite data class (it occurs when batch_size >= 2).

My versions:

torch 1.9.0
torch_geometric 2.0.2
pytorch-cluster           1.5.9           py39_torch_1.9.0_cu102    pyg
pytorch-scatter           2.0.9           py39_torch_1.9.0_cu102    pyg
pytorch-sparse            0.6.12          py39_torch_1.9.0_cu102    pyg
pytorch-spline-conv       1.2.1           py39_torch_1.9.0_cu102    pyg
Traceback (most recent call last):
  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/data1/Jay/.vscode-server/extensions/ms-python.python-2022.20.2/pythonFiles/lib/python/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/data1/Jay/reproduce/xytest_train.py", line 157, in <module>
    for batch_No, batch in enumerate(train_dataloader):
  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/loader/dataloader.py", line 34, in __call__
    return [self(s) for s in zip(*batch)]
  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/loader/dataloader.py", line 34, in <listcomp>
    return [self(s) for s in zip(*batch)]
  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/loader/dataloader.py", line 19, in __call__
    return Batch.from_data_list(batch, self.follow_batch,
  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 69, in from_data_list
    batch, slice_dict, inc_dict = collate(
  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/data/collate.py", line 33, in collate
    out = cls(_base_cls=data_list[0].__class__)  # Dynamic inheritance.
  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 41, in __call__
    return super(DynamicInheritance, new_cls).__call__(*args, **kwargs)
  File "/data1/Jay/reproduce/utility.py", line 324, in __init__
    self.var_feats = torch.tensor(var_feats.clone(), dtype=torch.float32)
AttributeError: 'NoneType' object has no attribute 'clone'

The source code of my class is:

import torch
from torch_geometric.data import Data, Dataset  # Dataset import assumed; not shown in the original post

class BipartiteData_Lite(Data):
    """custom bipartite data"""
    def __init__(
        self,
        cstr_feats,
        edge_indices,
        edge_values,
        var_feats,
    ):

        super(BipartiteData_Lite, self).__init__()

        self.var_feats = torch.tensor(var_feats.clone(), dtype=torch.float32)
        self.cstr_feats = torch.tensor(cstr_feats, dtype=torch.float32)
        self.edge_index = torch.tensor(edge_indices, dtype=torch.long)
        self.edge_attr = torch.tensor(edge_values, dtype=torch.float32)
        self.num_nodes = self.var_feats.shape[1] + self.cstr_feats.shape[1]

    def __inc__(self, key, value, *args, **kwargs):
        if key == "edge_index":
            return torch.tensor([[self.cstr_feats.size(0)], [self.var_feats.size(0)]])
        else:
            return super().__inc__(key, value, *args, **kwargs)

class BipartiteDataset_Lite(Dataset):
    ...

Part of xytest_train.py (the training script) is as follows:

training_set = BipartiteDataset_Lite(training_set_path)

train_dataloader = DataLoader(training_set, batch_size=batch_size, shuffle=True)

for batch_No, batch in enumerate(train_dataloader):
    ...

It seems that when torch_geometric constructs a mini-batch, after the <batchsize> instances have been created, it calls BipartiteData_Lite.__init__() once more (to build the Batch object), this time passing None for every argument.

I think the issue comes from here:

  File "/data1/Jay/miniconda3/envs/env1/lib/python3.9/site-packages/torch_geometric/data/batch.py", line 41, in __call__
    return super(DynamicInheritance, new_cls).__call__(*args, **kwargs)
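To illustrate why `__init__` ends up receiving `None`, here is a simplified, torch-free imitation of what the `DynamicInheritance` metaclass in PyG 2.0.x appears to do right before that `return`: it inspects the base class's `__init__` signature and passes `None` for every parameter that has no default. The helper `construct_like_pyg` and the classes `NeedsTensor` / `OptionalArgs` are hypothetical stand-ins for this sketch, not PyG APIs.

```python
import inspect


def construct_like_pyg(cls, *args, **kwargs):
    """Hypothetical imitation of DynamicInheritance.__call__ in PyG 2.0.x:
    every required __init__ parameter that was not supplied is set to None."""
    params = list(inspect.signature(cls.__init__).parameters.items())
    for i, (name, param) in enumerate(params[1:]):  # skip `self`
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        if i < len(args) or name in kwargs:
            continue  # already supplied by the caller
        if param.default is inspect.Parameter.empty:
            kwargs[name] = None  # required and missing -> None
    return cls(*args, **kwargs)


class NeedsTensor:
    """Mimics BipartiteData_Lite: every argument is required."""

    def __init__(self, cstr_feats, edge_indices, edge_values, var_feats):
        self.var_feats = var_feats.clone()  # fails when var_feats is None


class OptionalArgs:
    """Same constructor, but every argument defaults to None."""

    def __init__(self, cstr_feats=None, edge_indices=None,
                 edge_values=None, var_feats=None):
        if var_feats is None:  # empty shell created during batching
            return
        self.var_feats = var_feats.clone()


try:
    construct_like_pyg(NeedsTensor)
except AttributeError as err:
    print(err)  # prints: 'NoneType' object has no attribute 'clone'

empty = construct_like_pyg(OptionalArgs)  # no error: __init__ returns early
```

The second class shows why making every argument optional (and bailing out when they are missing) avoids the crash.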

Best regards.

I’m not familiar enough with pytorch_geometric, but the issue seems to be raised here.
The Data class seems to expect:

def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
                 edge_attr: OptTensor = None, y: OptTensor = None,
                 pos: OptTensor = None, **kwargs):

in its __init__ method, i.e. every argument defaults to None, which is also the pattern used in this tutorial.
I hope these pointers help narrow down the issue, as I unfortunately don’t know enough about this library to pinpoint the error (let’s wait for an expert to chime in).

Thanks for your reply!
Now I’ve added

if edge_values is None:
    return

to the __init__ method, and it works now :slight_smile: