I have updated the training loop to send tensors to the model that can be split in their batch dimension like this:
for snapshot in train_dataset:
x, edge_index, edge_weight = snapshot.x, snapshot.edge_index, snapshot.edge_attr
x = torch.flatten(x, start_dim=1).to(device)
edge_index = edge_index.to(device)
edge_weight = edge_weight.to(device)
y_hat = model(x, edge_index, edge_weight)
This solves the previous error, but it results in a new one:
Exception has occurred: IndexError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
Caught IndexError in replica 0 on device 0.
Original Traceback (most recent call last):
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py”, line 61, in _worker
output = module(*input, **kwargs)
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1102, in _call_impl
return forward_call(*input, **kwargs)
File “/mnt/raid0/users/acg384/workspace/code/LSTM_ddp2.py”, line 49, in forward
h_0 = self.recurrent1(x, edge_index, edge_weight)
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1102, in _call_impl
return forward_call(*input, **kwargs)
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch_geometric_temporal/nn/recurrent/gconv_gru.py”, line 163, in forward
Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch_geometric_temporal/nn/recurrent/gconv_gru.py”, line 120, in _calculate_update_gate
Z = self.conv_x_z(X, edge_index, edge_weight)
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1102, in _call_impl
return forward_call(*input, **kwargs)
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch_geometric/nn/conv/cheb_conv.py”, line 143, in forward
edge_index, norm = self.norm(edge_index, x.size(self.node_dim),
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch_geometric/nn/conv/cheb_conv.py”, line 110, in norm
edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch_geometric/utils/loop.py”, line 36, in remove_self_loops
mask = edge_index[0] != edge_index[1]
IndexError: index 1 is out of bounds for dimension 0 with size 1
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/_utils.py”, line 434, in reraise
raise exception
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py”, line 86, in parallel_apply
output.reraise()
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py”, line 178, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py”, line 168, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/modules/module.py”, line 1102, in _call_impl
return forward_call(*input, **kwargs)
File “/mnt/raid0/users/acg384/workspace/code/LSTM_ddp2.py”, line 176, in
y_hat = model(x, edge_index, edge_weight)
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/runpy.py”, line 87, in _run_code
exec(code, run_globals)
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/runpy.py”, line 97, in _run_module_code
_run_code(code, mod_globals, init_globals,
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/runpy.py”, line 268, in run_path
return _run_module_code(code, init_globals, run_name,
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/runpy.py”, line 87, in _run_code
exec(code, run_globals)
File “/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/runpy.py”, line 197, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,
This seems to be triggered by some internals within pytorch geometric.