RuntimeError: index i is out of bounds for dimension 0 with size i

Hi,

I get a RuntimeError (an index out-of-bounds error) when trying to train a Graph Neural Network (GNN). As brief background: I have a bi-modal network, where one side (cell) is a GNN and the other one (drug) is a simple NN. My model class for the GNN looks like this:

class GraphTab_v1(torch.nn.Module):
    """Bi-modal model; only the cell (graph) branch is shown in this excerpt.

    The cell branch applies two GCN layers, mean-pools the node features of
    each graph in the batch, and passes the pooled vector through a two-layer
    MLP that outputs 128 features per graph.
    """

    def __init__(self):
        super().__init__()
        torch.manual_seed(12345)

        # Drug branch. Not important for this question

        # Cell branch.
        # `batch` has to be declared as an input of the container because the
        # pooling step below consumes it ('x, batch -> x'); without it the
        # generated forward has no `batch` variable to hand to the pooling op.
        self.cell_emb = Sequential('x, edge_index, batch', [
            (GCNConv(in_channels=4, out_channels=256), 'x, edge_index -> x'), # TODO: try GATConv() vs GCNConv()
            nn.ReLU(inplace=True),
            (GCNConv(in_channels=256, out_channels=256), 'x, edge_index -> x'),
            nn.ReLU(inplace=True),
            # Readout: collapse node features to one vector per graph.
            (global_mean_pool, 'x, batch -> x'),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        ])

        # ...

    def forward(self, cell, drug):
        """Compute the cell-branch embedding for a batch of graphs.

        Args:
            cell: Batched graph object; its `x`, `edge_index` and `batch`
                attributes are read here.
            drug: Input of the drug branch (not used in this excerpt).

        NOTE(review): the reported error (index 858000 for a dimension of
        size 858000) indicates that `cell.edge_index` contains a node id equal
        to the number of nodes. Node ids must lie in [0, num_nodes);
        presumably the ids are 1-based where the graphs are built -- verify
        the dataset construction.
        """
        # Drug branch. Not important for this question.
        # Cell branch.
        # Pass the node-to-graph assignment so the pooling step can group nodes.
        cell_emb = self.cell_emb(cell.x, cell.edge_index, cell.batch)

I have a batch size of 1000 and each graph has the following topology

Data(x=[858, 4], edge_index=[2, 83126])

With a batch size of 1000, this means that each batch is a single large disconnected graph that holds all 1000 graphs of the batch:

DataBatch(x=[858000, 4], edge_index=[2, 83126000], batch=[858000], ptr=[1001])

Now, when I run the following in my train() method

preds = self.model(cell, drug.float())

I get the following error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/Users/BLABLA.ipynb Cell 53' in <cell line: 22>()
     12 optimizer = torch.optim.Adam(params=model.parameters(), lr=LR) # TODO: include weight_decay of lr
     14 build_model = BuildModel(model=model,
     15                          criterion=loss_func,
     16                          optimizer=optimizer,
   (...)
     19                          test_loader=test_loader,
     20                          val_loader=val_loader)
---> 22 build_model.train()

/Users/BLABLA.ipynb Cell 52' in BuildModel.train(self)
     32 print(f"targets      : {targets[:10]}")
     34 # Models predictions of the ic50s for a batch of cell-lines and drugs
---> 35 preds = self.model(cell, drug.float())
     36 print(targets)

File /users/BLABLA/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/BLABLA.py:77, in GraphTab_v1.forward(self, cell, drug)
     70 print(f"drug_emb: {drug_emb}")
     72 # cell_gnn_out = self.cell_gnn(cell.x, cell.edge_index)
     73 # print(f"cell_gnn_out: {cell_gnn_out}")
     74 # # Readout layer.
     75 # cell_gnn = global_mean_pool(cell_gnn_out, cell) # [batch_size, hidden_channels]
     76 # cell_emb = self.cell_emb(cell_gnn)
---> 77 cell_emb = self.cell_emb(cell.x, cell.edge_index)
     78 print(f"cell_emb: {cell_emb}")
     81 # ----------------------------------------------------- #
     82 # Concatenate the outputs of the cell and drug branches #
     83 # ----------------------------------------------------- #

File /users/BLABLA/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File /var/BLABLA/tmpqgyjz410.py:18, in Sequential_59c724.forward(self, x, edge_index)
     16 def forward(self, x, edge_index):
     17     """"""
---> 18     x = self.module_0(x, edge_index)
     19     x = self.module_1(x)
     20     x = self.module_2(x, edge_index)
File /users/BLABLA/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File /users/BLABLA/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py:172, in GCNConv.forward(self, x, edge_index, edge_weight)
    170 cache = self._cached_edge_index
    171 if cache is None:
--> 172     edge_index, edge_weight = gcn_norm(  # yapf: disable
    173         edge_index, edge_weight, x.size(self.node_dim),
    174         self.improved, self.add_self_loops)
    175     if self.cached:
    176         self._cached_edge_index = (edge_index, edge_weight)

File /users/BLABLA/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py:64, in gcn_norm(edge_index, edge_weight, num_nodes, improved, add_self_loops, dtype)
     61     edge_weight = tmp_edge_weight
     63 row, col = edge_index[0], edge_index[1]
---> 64 deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
     65 deg_inv_sqrt = deg.pow_(-0.5)
     66 deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)

File /users/BLABLA/python3.10/site-packages/torch_scatter/scatter.py:29, in scatter_add(src, index, dim, out, dim_size)
     26 def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     27                 out: Optional[torch.Tensor] = None,
     28                 dim_size: Optional[int] = None) -> torch.Tensor:
---> 29     return scatter_sum(src, index, dim, out, dim_size)

File /users/BLABLA/python3.10/site-packages/torch_scatter/scatter.py:21, in scatter_sum(src, index, dim, out, dim_size)
     19         size[dim] = int(index.max()) + 1
     20     out = torch.zeros(size, dtype=src.dtype, device=src.device)
---> 21     return out.scatter_add_(dim, index, src)
     22 else:
     23     return out.scatter_add_(dim, index, src)

RuntimeError: index 858000 is out of bounds for dimension 0 with size 858000

I know that a similar problem was asked about in this question; however, unfortunately, that doesn’t help me, or at least I don’t see how it can help me.

It looks to me like it tries to access an index equal to the number of nodes in the full graph, but the last index it can access is 858000 - 1. So maybe it starts counting the indices from 1 instead of 0, which is why the last index is 858000. But that’s just a guess.

I appreciate any help!