Hi,
I'm getting an index-out-of-bounds RuntimeError when trying to train a Graph Neural Network (GNN). As brief background: I have a bi-modal network, where one side (cell) is a GNN and the other (drug) a simple NN. My model class for the GNN looks like this:
import torch
import torch.nn as nn
from torch_geometric.nn import Sequential, GCNConv, global_mean_pool

class GraphTab_v1(torch.nn.Module):
    def __init__(self):
        super(GraphTab_v1, self).__init__()
        torch.manual_seed(12345)
        # Drug branch. Not important for this question.
        # Cell branch.
        self.cell_emb = Sequential('x, edge_index', [
            (GCNConv(in_channels=4, out_channels=256), 'x, edge_index -> x'),  # TODO: try GATConv() vs GCNConv()
            nn.ReLU(inplace=True),
            (GCNConv(in_channels=256, out_channels=256), 'x, edge_index -> x'),
            nn.ReLU(inplace=True),
            (global_mean_pool, 'x, batch -> x'),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU()
        ])
        # ...

    def forward(self, cell, drug):
        # Drug branch. Not important for this question.
        # Cell branch.
        cell_emb = self.cell_emb(cell.x, cell.edge_index)
        # ...
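(A side note, separate from the error below: if I understand torch_geometric.nn.Sequential correctly, the global_mean_pool step needs the batch assignment vector, which the 'x, edge_index' signature above never receives. My assumption of how the wiring would eventually have to look, sketched for completeness; the crash happens before this layer is even reached, so it is not the cause:)

# Sketch (my assumption): thread the batch vector through the Sequential
# so that global_mean_pool can receive it.
self.cell_emb = Sequential('x, edge_index, batch', [
    (GCNConv(in_channels=4, out_channels=256), 'x, edge_index -> x'),
    nn.ReLU(inplace=True),
    (GCNConv(in_channels=256, out_channels=256), 'x, edge_index -> x'),
    nn.ReLU(inplace=True),
    (global_mean_pool, 'x, batch -> x'),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 128),
    nn.ReLU()
])
# and in forward():
# cell_emb = self.cell_emb(cell.x, cell.edge_index, cell.batch)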
I have a batch size of 1000, and each graph has the following topology:

Data(x=[858, 4], edge_index=[2, 83126])

With a batch size of 1000, each batch is therefore a single large disconnected graph that holds all 1000 graphs of the batch:

DataBatch(x=[858000, 4], edge_index=[2, 83126000], batch=[858000], ptr=[1001])
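For reference, my understanding of how the DataLoader builds this disconnected graph: the node indices of the k-th graph in the batch get shifted by k * 858, so valid node indices run from 0 to 857999. A toy sketch with random edges (not my real data):

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# 1000 toy graphs with 858 nodes each and random (0-based) edge indices.
graphs = [Data(x=torch.randn(858, 4),
               edge_index=torch.randint(0, 858, (2, 10)))
          for _ in range(1000)]
loader = DataLoader(graphs, batch_size=1000)
batch = next(iter(loader))
print(batch)                   # DataBatch(x=[858000, 4], edge_index=[2, 10000], batch=[858000], ptr=[1001])
print(batch.edge_index.max())  # <= 857999, as long as every graph is 0-indexed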
Now, when I run the following in my train() method:

preds = self.model(cell, drug.float())

I get the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/Users/BLABLA.ipynb Cell 53' in <cell line: 22>()
12 optimizer = torch.optim.Adam(params=model.parameters(), lr=LR) # TODO: include weight_decay of lr
14 build_model = BuildModel(model=model,
15 criterion=loss_func,
16 optimizer=optimizer,
(...)
19 test_loader=test_loader,
20 val_loader=val_loader)
---> 22 build_model.train()
/Users/BLABLA.ipynb Cell 52' in BuildModel.train(self)
32 print(f"targets : {targets[:10]}")
34 # Models predictions of the ic50s for a batch of cell-lines and drugs
---> 35 preds = self.model(cell, drug.float())
36 print(targets)
File /users/BLABLA/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File ~/BLABLA.py:77, in GraphTab_v1.forward(self, cell, drug)
70 print(f"drug_emb: {drug_emb}")
72 # cell_gnn_out = self.cell_gnn(cell.x, cell.edge_index)
73 # print(f"cell_gnn_out: {cell_gnn_out}")
74 # # Readout layer.
75 # cell_gnn = global_mean_pool(cell_gnn_out, cell) # [batch_size, hidden_channels]
76 # cell_emb = self.cell_emb(cell_gnn)
---> 77 cell_emb = self.cell_emb(cell.x, cell.edge_index)
78 print(f"cell_emb: {cell_emb}")
81 # ----------------------------------------------------- #
82 # Concatenate the outputs of the cell and drug branches #
83 # ----------------------------------------------------- #
File /users/BLABLA/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File /var/BLABLA/tmpqgyjz410.py:18, in Sequential_59c724.forward(self, x, edge_index)
16 def forward(self, x, edge_index):
17 """"""
---> 18 x = self.module_0(x, edge_index)
19 x = self.module_1(x)
20 x = self.module_2(x, edge_index)
File /users/BLABLA/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File /users/BLABLA/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py:172, in GCNConv.forward(self, x, edge_index, edge_weight)
170 cache = self._cached_edge_index
171 if cache is None:
--> 172 edge_index, edge_weight = gcn_norm( # yapf: disable
173 edge_index, edge_weight, x.size(self.node_dim),
174 self.improved, self.add_self_loops)
175 if self.cached:
176 self._cached_edge_index = (edge_index, edge_weight)
File /users/BLABLA/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py:64, in gcn_norm(edge_index, edge_weight, num_nodes, improved, add_self_loops, dtype)
61 edge_weight = tmp_edge_weight
63 row, col = edge_index[0], edge_index[1]
---> 64 deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
65 deg_inv_sqrt = deg.pow_(-0.5)
66 deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
File /users/BLABLA/python3.10/site-packages/torch_scatter/scatter.py:29, in scatter_add(src, index, dim, out, dim_size)
26 def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
27 out: Optional[torch.Tensor] = None,
28 dim_size: Optional[int] = None) -> torch.Tensor:
---> 29 return scatter_sum(src, index, dim, out, dim_size)
File /users/BLABLA/python3.10/site-packages/torch_scatter/scatter.py:21, in scatter_sum(src, index, dim, out, dim_size)
19 size[dim] = int(index.max()) + 1
20 out = torch.zeros(size, dtype=src.dtype, device=src.device)
---> 21 return out.scatter_add_(dim, index, src)
22 else:
23 return out.scatter_add_(dim, index, src)
RuntimeError: index 858000 is out of bounds for dimension 0 with size 858000
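If I read the failing gcn_norm call correctly, col holds the target node index of every edge and dim_size the total node count of the batch, so the error boils down to this minimal sketch (numbers made up except for the 858000):

import torch
from torch_scatter import scatter_add

# One edge points at node 858000, but with dim_size=858000 the valid
# targets are 0..857999 -> same RuntimeError as in the traceback above.
edge_weight = torch.ones(3)
col = torch.tensor([0, 1, 858000])
scatter_add(edge_weight, col, dim=0, dim_size=858000)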
I know that a similar question was asked in this question; unfortunately, it doesn't help me, or at least I don't see how it applies here.
It looks to me like it tries to access an index equal to the total number of nodes in the batch graph, while the largest valid index is 858000 - 1. So maybe the node indices in my graphs are counted from 1 instead of 0, which would make the largest index 858000. But that's just an idea.
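To test this idea, I guess the check would look roughly like this (sketch; data stands for one of my Data(x=[858, 4], edge_index=[2, 83126]) objects):

num_nodes = data.x.size(0)           # 858
print(data.edge_index.min().item())  # 0 for 0-based, 1 for 1-based indexing
print(data.edge_index.max().item())  # must be <= num_nodes - 1 == 857
# If the indices turn out to be 1-based, shifting them should fix the batch:
# data.edge_index = data.edge_index - 1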
I appreciate any help!