Sizes of tensors must match except in dimension 0. Expected size 914 but got size 531 for tensor number 1 in the list

I’m working with PyTorch Geometric data objects, which are graphs with different numbers of nodes, so attributes such as data.x and data.edge_attr have different dimensions from one graph to the next when I load the data. The problem occurs in the batching step, which requires the data dimensions to match. Are there other solutions besides padding?
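
For context, here is a minimal, self-contained sketch (toy graphs, not my real data) of how PyG normally batches graphs of different sizes without any padding: node-level tensors are concatenated along dimension 0, so only the feature dimensions need to agree across graphs.

import torch
from torch_geometric.data import Data, Batch

# Two toy graphs with different numbers of nodes but the same feature width (4).
g1 = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.randn(5, 4), edge_index=torch.tensor([[0, 2], [2, 4]]))

batch = Batch.from_data_list([g1, g2])
print(batch.x.shape)  # torch.Size([8, 4]) -- nodes are concatenated, no padding
print(batch.batch)    # tensor([0, 0, 0, 1, 1, 1, 1, 1]) maps each node to its graph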
The error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-53-dee57c4cc7a6> in <cell line: 0>()
----> 1 for batch in mini_train_loader:
      2     print(batch)  # Should now print a batch object instead of crashing
      3     break



/usr/local/lib/python3.11/dist-packages/torch_geometric/data/collate.py in _collate(key, values, data_list, stores, increment)
    203             out = elem.new(storage).resize_(*shape)
    204 
--> 205         value = torch.cat(values, dim=cat_dim or 0, out=out)
    206 
    207         if increment and isinstance(value, Index) and values[0].is_sorted:

RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 914 but got size 531 for tensor number 1 in the list.

The code

mini_data_pairs
[(Data(x=[29, 86], edge_index=[2, 62], edge_attr=[62, 24], num_nodes=29),
  Data(x=[183, 26], edge_index=[2, 1469], edge_attr=[1469, 1], num_nodes=183, adjacency_matrix=[183, 183]),
  tensor(0.)),
 (Data(x=[24, 86], edge_index=[2, 48], edge_attr=[48, 24], num_nodes=24),
  Data(x=[263, 26], edge_index=[2, 2397], edge_attr=[2397, 1], num_nodes=263, adjacency_matrix=[263, 263]),
  tensor(0.)),
 (Data(x=[43, 86], edge_index=[2, 88], edge_attr=[88, 24], num_nodes=43),
  Data(x=[367, 26], edge_index=[2, 2974], edge_attr=[2974, 1], num_nodes=367, adjacency_matrix=[367, 367]),
  tensor(1.)),
 (Data(x=[27, 86], edge_index=[2, 58], edge_attr=[58, 24], num_nodes=27),
  Data(x=[503, 26], edge_index=[2, 4268], edge_attr=[4268, 1], num_nodes=503, adjacency_matrix=[503, 503]),
  tensor(1.)),
 (Data(x=[30, 86], edge_index=[2, 62], edge_attr=[62, 24], num_nodes=30),
  Data(x=[385, 26], edge_index=[2, 3095], edge_attr=[3095, 1], num_nodes=385, adjacency_matrix=[385, 385]),
  tensor(0.)),
 (Data(x=[54, 86], edge_index=[2, 122], edge_attr=[122, 24], num_nodes=54),
  Data(x=[531, 26], edge_index=[2, 3843], edge_attr=[3843, 1], num_nodes=531, adjacency_matrix=[531, 531]),
  tensor(0.)),
 (Data(x=[17, 86], edge_index=[2, 36], edge_attr=[36, 24], num_nodes=17),
  Data(x=[177, 26], edge_index=[2, 916], edge_attr=[916, 1], num_nodes=177, adjacency_matrix=[177, 177]),
  tensor(0.)),
 (Data(x=[46, 86], edge_index=[2, 98], edge_attr=[98, 24], num_nodes=46),
  Data(x=[1356, 26], edge_index=[2, 9633], edge_attr=[9633, 1], num_nodes=1356, adjacency_matrix=[1356, 1356]),
  tensor(1.)),
 (Data(x=[31, 86], edge_index=[2, 66], edge_attr=[66, 24], num_nodes=31),
  Data(x=[477, 26], edge_index=[2, 3011], edge_attr=[3011, 1], num_nodes=477, adjacency_matrix=[477, 477]),
  tensor(1.)),
 (Data(x=[18, 86], edge_index=[2, 38], edge_attr=[38, 24], num_nodes=18),
  Data(x=[914, 26], edge_index=[2, 8499], edge_attr=[8499, 1], num_nodes=914, adjacency_matrix=[914, 914]),
  tensor(0.))]
def collate_fn(batch):
    """
    Collate function to create batches for DataLoader.

    Args:
    - batch (list): List of (drug_graph, protein_graph, label) tuples.

    Returns:
    - Tuple of (list of drug graphs, list of protein graphs, tensor of labels).
    """
    drug_graphs, protein_graphs, labels = zip(*batch)


    # Batch graphs using PyG's batching mechanism
    drug_batch = Batch.from_data_list(drug_graphs)
    protein_batch = Batch.from_data_list(protein_graphs)

    # Convert labels to a single tensor for efficiency
    labels_tensor = torch.stack(labels)  # More efficient than creating a new tensor

    return list(drug_graphs), list(protein_graphs), labels_tensor
mini_train_loader = DataLoader(mini_data_pairs, batch_size=2, shuffle=True, collate_fn=collate_fn)

Your custom collate_fn doesn’t seem to be used since the error is raised from:

/usr/local/lib/python3.11/dist-packages/torch_geometric/data/collate.py in _collate(key, values, data_list, stores, increment)
    203             out = elem.new(storage).resize_(*shape)
    204 
--> 205         value = torch.cat(values, dim=cat_dim or 0, out=out)

which points to PyG's internal _collate function.
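
If the DataLoader here is PyG's torch_geometric.loader.DataLoader (which the traceback suggests, since its Collater runs instead of your function), it drops any collate_fn passed to it and installs its own Collater. A minimal sketch that keeps your collate_fn, assuming you switch to the plain torch.utils.data.DataLoader:

import torch
from torch.utils.data import DataLoader  # plain PyTorch loader: keeps a custom collate_fn
from torch_geometric.data import Batch

def collate_fn(batch):
    drug_graphs, protein_graphs, labels = zip(*batch)
    drug_batch = Batch.from_data_list(list(drug_graphs))
    protein_batch = Batch.from_data_list(list(protein_graphs))
    # Note: from_data_list can still fail for attributes whose non-node
    # dimensions differ between graphs (e.g. a dense [N, N] adjacency_matrix).
    return drug_batch, protein_batch, torch.stack(labels)

mini_train_loader = DataLoader(mini_data_pairs, batch_size=2, shuffle=True,
                               collate_fn=collate_fn)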

_collate is called by the Batch.from_data_list() function; here is the full error (see the minimal sketch after the traceback):

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-25-389cb9d1134b> in <cell line: 0>()
      9     return drug_batch, protein_batch, labels
     10 dataloader = DataLoader(mini_data_pairs, batch_size=2, collate_fn=coll , shuffle=True)
---> 11 for batch in dataloader:
     12     print(batch)
     13     break

8 frames
/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    699                 # TODO(https://github.com/pytorch/pytorch/issues/76750)
    700                 self._reset()  # type: ignore[call-arg]
--> 701             data = self._next_data()
    702             self._num_yielded += 1
    703             if (

/usr/local/lib/python3.11/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    755     def _next_data(self):
    756         index = self._next_index()  # may raise StopIteration
--> 757         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    758         if self._pin_memory:
    759             data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     53         else:
     54             data = self.dataset[possibly_batched_index]
---> 55         return self.collate_fn(data)

/usr/local/lib/python3.11/dist-packages/torch_geometric/loader/dataloader.py in __call__(self, batch)
     45             return type(elem)(*(self(s) for s in zip(*batch)))
     46         elif isinstance(elem, Sequence) and not isinstance(elem, str):
---> 47             return [self(s) for s in zip(*batch)]
     48 
     49         raise TypeError(f"DataLoader found invalid type: '{type(elem)}'")

/usr/local/lib/python3.11/dist-packages/torch_geometric/loader/dataloader.py in <listcomp>(.0)
     45             return type(elem)(*(self(s) for s in zip(*batch)))
     46         elif isinstance(elem, Sequence) and not isinstance(elem, str):
---> 47             return [self(s) for s in zip(*batch)]
     48 
     49         raise TypeError(f"DataLoader found invalid type: '{type(elem)}'")

/usr/local/lib/python3.11/dist-packages/torch_geometric/loader/dataloader.py in __call__(self, batch)
     25         elem = batch[0]
     26         if isinstance(elem, BaseData):
---> 27             return Batch.from_data_list(
     28                 batch,
     29                 follow_batch=self.follow_batch,

/usr/local/lib/python3.11/dist-packages/torch_geometric/data/batch.py in from_data_list(cls, data_list, follow_batch, exclude_keys)
     95         Will exclude any keys given in :obj:`exclude_keys`.
     96         """
---> 97         batch, slice_dict, inc_dict = collate(
     98             cls,
     99             data_list=data_list,

/usr/local/lib/python3.11/dist-packages/torch_geometric/data/collate.py in collate(cls, data_list, increment, add_batch, follow_batch, exclude_keys)
    107 
    108             # Collate attributes into a unified representation:
--> 109             value, slices, incs = _collate(attr, values, data_list, stores,
    110                                            increment)
    111 

/usr/local/lib/python3.11/dist-packages/torch_geometric/data/collate.py in _collate(key, values, data_list, stores, increment)
    203             out = elem.new(storage).resize_(*shape)
    204 
--> 205         value = torch.cat(values, dim=cat_dim or 0, out=out)
    206 
    207         if increment and isinstance(value, Index) and values[0].is_sorted:

RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 263 but got size 531 for tensor number 1 in the list.
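
To confirm, I can reproduce the same error by calling Batch.from_data_list directly on two graphs shaped like the protein graphs above (random tensors, sizes taken from mini_data_pairs):

import torch
from torch_geometric.data import Data, Batch

# Shapes copied from two of the protein graphs in mini_data_pairs.
p1 = Data(x=torch.randn(263, 26), edge_index=torch.randint(0, 263, (2, 2397)),
          edge_attr=torch.randn(2397, 1), num_nodes=263,
          adjacency_matrix=torch.zeros(263, 263))
p2 = Data(x=torch.randn(531, 26), edge_index=torch.randint(0, 531, (2, 3843)),
          edge_attr=torch.randn(3843, 1), num_nodes=531,
          adjacency_matrix=torch.zeros(531, 531))

# x, edge_index and edge_attr only differ in the dimension PyG concatenates
# over, but adjacency_matrix is [263, 263] vs [531, 531], so torch.cat along
# dim 0 fails because dim 1 does not match -> the RuntimeError above.
Batch.from_data_list([p1, p2])

The mismatch therefore seems to come from the adjacency_matrix attribute, which is [N, N] per graph. From the docstring visible in the traceback, from_data_list accepts exclude_keys, so skipping adjacency_matrix during batching might be one padding-free workaround, if I understand it correctly.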