DataParallel Not working with timm Model


I am trying to train a model that is based on the timm library.

My main function looks like this:
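```python
import torch
import torch.nn as nn

model = DPTDepthModel(...)  # constructor arguments were not included in the post

loss_function = MSE()  # project-specific loss (definition not shown)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model, device_ids=[0, 1])
    print("Model and Loss Function are on multiple GPUs now")

for epoch in range(0, 5):
    newmodel = train(model, train_loader, loss_function, optimizer, epoch)
    # The original line had unbalanced parentheses; saving one checkpoint per
    # epoch is assumed from the "model<epoch>.pt" file name.
    torch.save(newmodel, "model" + str(epoch) + ".pt")
```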


```
  File "", line 482, in main2
  File "", line 381, in train
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/torch/nn/modules/", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/torch/nn/parallel/", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/torch/nn/parallel/", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/torch/nn/parallel/", line 86, in parallel_apply
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/torch/", line 434, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/torch/nn/parallel/", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/torch/nn/modules/", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/talha/SemAttNet-Thesis-Code-main/dpt/", line 165, in forward
    inv_depth = super().forward(x).squeeze(dim=1)
  File "/home/talha/SemAttNet-Thesis-Code-main/dpt/", line 78, in forward
    glob = self.pretrained.model.forward_flex(x)
  File "/home/talha/SemAttNet-Thesis-Code-main/dpt/", line 175, in forward_flex
    x = self.patch_embed.backbone(x)
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/torch/nn/modules/", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/timm/models/", line 418, in forward
    x = self.forward_features(x)
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/timm/models/", line 412, in forward_features
    x = self.stem(x)
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/torch/nn/modules/", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/torch/nn/modules/", line 141, in forward
    input = module(input)
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/torch/nn/modules/", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/talha/anaconda3/envs/SemNet2/lib/python3.6/site-packages/timm/models/layers/", line 72, in forward
    x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__cudnn_convolution)
```

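Based on the stack trace, the error is raised in `x = self.patch_embed.backbone(x)` inside `forward_flex`: the convolution running in replica 1 receives its input on `cuda:1` but its weight on `cuda:0`, so the backbone used by this call seems to live on `cuda:0` even in the replica running on `cuda:1`.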
I'm not familiar with the DPT repository, so I don't know exactly how these methods are used.
`nn.DataParallel` splits the input batch passed to its forward into chunks along dimension 0 and sends each chunk to the corresponding GPU, which holds a replica of the model.
This means you should call the wrapped model directly, e.g. `output = model(input)`, instead of calling `model.forward` or another method of the underlying module.
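One possible mechanism behind the traceback can be illustrated without GPUs. This is a hypothesis, not something confirmed in the thread: it assumes the DPT code attaches `forward_flex` to the backbone *instance* as a bound method (for example via `types.MethodType`). In recent PyTorch versions, replication for `nn.DataParallel` starts from a shallow copy of each module's `__dict__`, so a bound method stored there still points at the original module (on `cuda:0`), not at the replica. The pure-Python sketch below mimics that; `Backbone`, `Model` and `replicate` are hypothetical stand-ins and the device strings are just labels.

```python
import types


class Backbone:
    def __init__(self, device):
        self.device = device  # stands in for "where the weights live"


def forward_flex(self, x):
    # Uses whichever module `self` is bound to.
    return self.backbone.device


class Model:
    def __init__(self, device):
        self.backbone = Backbone(device)
        # Hypothetical: helper attached to the instance as a bound method.
        self.forward_flex = types.MethodType(forward_flex, self)


def replicate(module, device):
    """Mimics replication: shallow-copy __dict__, then swap in a per-device child."""
    replica = module.__class__.__new__(module.__class__)
    replica.__dict__ = module.__dict__.copy()
    replica.backbone = Backbone(device)
    return replica


original = Model("cuda:0")
replica = replicate(original, "cuda:1")

# The stored bound method still targets the original module's backbone.
print(replica.forward_flex(None))
# Calling the plain function with the replica uses the replica's backbone.
print(forward_flex(replica, None))
```

If this is what happens in the DPT code, calling the plain function with the replica, e.g. `forward_flex(self.pretrained.model, x)`, instead of the stored method would keep the computation on the replica's GPU; any other helpers attached the same way would need the same treatment.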
Also, make sure that the forward method does not create new tensors on the default GPU. If you need to create tensors there, take the device from the `.device` attribute of the input tensor or of one of the module's parameters.
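A minimal sketch of the last point, using a hypothetical `ScaledBlock` layer that needs helper tensors inside `forward`: the tensors take their device from the input or from a parameter instead of a hard-coded GPU, so each replica keeps them on its own device. It falls back to the CPU when no GPU is available.

```python
import torch
import torch.nn as nn


class ScaledBlock(nn.Module):
    """Hypothetical layer that needs helper tensors inside forward()."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Don't pin new tensors to a fixed device such as "cuda:0":
        # this replica may be running on cuda:1.
        # Take the device (and dtype) from the input tensor...
        offsets = torch.arange(x.shape[-1], device=x.device, dtype=x.dtype)
        # ...or from one of the module's own parameters.
        gain = torch.ones(1, device=self.proj.weight.device)
        return self.proj(x + offsets) * gain


if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = ScaledBlock(8).to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)  # call model(...) directly, not model.module.forward
    out = model(torch.randn(4, 8, device=device))
    print(out.shape)  # torch.Size([4, 8])
```

Tensors stored as plain attributes in `__init__` are not copied to the replicas either; registering them with `register_buffer` lets `nn.DataParallel` place a copy on each GPU.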