DataParallel expecting all tensors on the same device

Hello, I am using DataParallel in a similar way to what is shown in this tutorial.
I run into the following error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat2 in method wrapper_mm)

This happens right after the first forward pass.
The model architecture is built with PyTorch Geometric Temporal.


# How the data is loaded
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.7)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Model definition
import torch
import torch.nn.functional as F

class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        torch.manual_seed(1234567)
        # From PyTorch Geometric Temporal
        self.recurrent1 = SomeGCN(node_features, 32, 1)
        self.fc1 = torch.nn.Linear(32, 1)
        self.sigmoid = torch.sigmoid
        self.dropout = torch.nn.Dropout(0.2)
    
    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        x = torch.flatten(x, start_dim=1)
        h_0 = self.recurrent1(x, edge_index, edge_weight)
        h = F.relu(h_0)
        h = self.fc1(h)
        h = self.sigmoid(h)
        h = self.dropout(h)
        return h 

# DataParallel instance
model = RecurrentGCN(node_features=n_features)
if torch.cuda.device_count() > 1:
    print("Available/CUDA_VISIBLE_DEVICES", os.environ["CUDA_VISIBLE_DEVICES"])
    print("Device count", torch.cuda.device_count())
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = DataParallel(model, device_ids=[0, 1])
model.to(device)

# During the training loop
for snapshot in train_dataset:
    snapshot = snapshot.to(device)
    y_hat = model(snapshot)  # Error happens here

The tutorial seems straightforward, so I do not see what I am doing wrong here.
One last note: the code runs without problems on a single GPU.

Thanks for the help!

Could you check the .device attribute of all tensors inside the forward method and compare them to the .device of the parameters?
They should match, and based on your description I would guess that the passed input might be a custom class, which is not recognized by nn.DataParallel and thus not split into chunks for each GPU.
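As a quick check, something along these lines inside forward (illustrative only, assuming the input exposes .x, .edge_index, and .edge_attr) should reveal any mismatch:

def forward(self, data):
    # Debug only: compare the devices of the inputs and the parameters.
    print("x:", data.x.device,
          "edge_index:", data.edge_index.device,
          "edge_attr:", data.edge_attr.device)
    print("fc1.weight:", self.fc1.weight.device,
          "recurrent1:", next(self.recurrent1.parameters()).device)
    ...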

I think you are right. I put a breakpoint in the forward method and I can see, for example, that the fc1 and recurrent1 parameters are on cuda:1 while the attribute tensors are on cuda:0.
How could I try to solve this?

Make sure that the tensors passed to the model can be split in their batch dimension.


I have updated the training loop to pass the model tensors that can be split in their batch dimension, like this:

for snapshot in train_dataset:
    x, edge_index, edge_weight = snapshot.x, snapshot.edge_index, snapshot.edge_attr
    x = torch.flatten(x, start_dim=1).to(device)
    edge_index = edge_index.to(device)
    edge_weight = edge_weight.to(device)
    y_hat = model(x, edge_index, edge_weight)

This solves the previous error, but it results in a new one:

Exception has occurred: IndexError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
Caught IndexError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/raid0/users/acg384/workspace/code/LSTM_ddp2.py", line 49, in forward
h_0 = self.recurrent1(x, edge_index, edge_weight)
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch_geometric_temporal/nn/recurrent/gconv_gru.py", line 163, in forward
Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch_geometric_temporal/nn/recurrent/gconv_gru.py", line 120, in _calculate_update_gate
Z = self.conv_x_z(X, edge_index, edge_weight)
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch_geometric/nn/conv/cheb_conv.py", line 143, in forward
edge_index, norm = self.norm(edge_index, x.size(self.node_dim),
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch_geometric/nn/conv/cheb_conv.py", line 110, in norm
edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch_geometric/utils/loop.py", line 36, in remove_self_loops
mask = edge_index[0] != edge_index[1]
IndexError: index 1 is out of bounds for dimension 0 with size 1
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/_utils.py", line 434, in reraise
raise exception
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
output.reraise()
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/mnt/raid0/users/acg384/workspace/code/LSTM_ddp2.py", line 176, in
y_hat = model(x, edge_index, edge_weight)
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/runpy.py", line 97, in _run_module_code
_run_code(code, mod_globals, init_globals,
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/runpy.py", line 268, in run_path
return _run_module_code(code, init_globals, run_name,
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/acg384/miniconda3/envs/pytorch_test/lib/python3.9/runpy.py", line 197, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,

This seems to be triggered by some internals within PyTorch Geometric.
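The size-1 dimension in the IndexError is consistent with how nn.DataParallel scatters inputs: every tensor is chunked along dim 0, so the [2, num_edges] edge_index arrives at each replica as a [1, num_edges] slice and edge_index[1] no longer exists. One option that might be worth trying (a sketch, not something confirmed in this thread) is PyTorch Geometric's own DataParallel together with a DataListLoader, which scatter whole graphs to each GPU instead of chunking raw tensors. This assumes a PyG version where DataListLoader lives in torch_geometric.loader and that the temporal snapshots can be collected into a plain list of Data objects:

# Untested sketch: PyG's DataParallel scatters whole Data objects per replica,
# so edge_index stays intact on every GPU.
from torch_geometric.loader import DataListLoader
from torch_geometric.nn import DataParallel

data_list = [snapshot for snapshot in train_dataset]  # list of Data objects
loader = DataListLoader(data_list, batch_size=batch_size, shuffle=False)

model = DataParallel(RecurrentGCN(node_features=n_features), device_ids=[0, 1])
model = model.to("cuda:0")  # output device

for batch in loader:
    # batch is a Python list of Data objects; DataParallel splits it across
    # the GPUs and rebuilds a Batch per replica, so forward(data) can keep
    # using data.x, data.edge_index and data.edge_attr.
    y_hat = model(batch)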