RuntimeError when multiple GPUs are used: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

Hello! I would like to compute the gradients of a tensor matrix T through another set of manually computed gradients (a second-order update). The code below works fine on a single GPU, but when multiple GPUs are used, the error above occurs. Can anyone share how to handle this? Thanks!

May I ask if you could have a look? Thank you very much! @ptrblck @albanD

My code:

import copy
from collections import OrderedDict

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

# model
model = models.resnet50(pretrained=True)
model.fc = nn.Linear(2048, num_classes)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device) 

# T is the tensor matrix whose gradients I want to compute through another set of gradients;
# move it to the device first and then set requires_grad, so that T is a leaf tensor
T = torch.from_numpy(numpy.random.rand(num_classes, num_classes)).float()
T = T.to(device).requires_grad_(True)

# I do not want to update the weights of the original model, so work on a deep copy
net = copy.deepcopy(model)

# keep a copy of net's weights for the manual gradient update below
original_weights = OrderedDict()
for name, param in net.named_parameters():
    if not param.requires_grad:
        print(name)
    else:
        original_weights[name] = copy.deepcopy(param)
original_weights_keys = tuple(original_weights.keys())

# data_s is a batch of images, target_s the corresponding labels
with torch.enable_grad():
    logits = net(data_s)
    # pre1 selects the target_s-th rows of T; loss_ is essentially a cross-entropy loss reweighted by T
    pre1 = T[target_s.long().to(T.device)]
    pre2 = torch.mul(F.softmax(logits, dim=1), pre1)
    loss_ = -torch.log(pre2.sum(1)).sum(0) / float(len(target_s))
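    # i.e. loss_ = -(1/N) * sum_i log( sum_j softmax(logits_i)_j * T[target_i, j] )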
    
    # manually compute gradients w.r.t. net's parameters; create_graph=True keeps the graph
    # so that second-order gradients (w.r.t. T) can flow through this step later
    net_grads = torch.autograd.grad(loss_, net.parameters(), retain_graph=True, create_graph=True, allow_unused=True)

    # see _del_nested_attr() and _set_nested_attr() in
    # https://discuss.pytorch.org/t/how-to-calculate-2nd-derivative-of-a-likelihood-function/15085/30
    # (the helpers are also sketched after the code below)
    for name, grad in zip(original_weights_keys, net_grads):
        if grad is None:
            continue
        new_weight = original_weights[name] - learning_rate * grad
        # remove the old nn.Parameter
        _del_nested_attr(net, name.split("."))
        # set the updated Tensor (which carries autograd history) in its place
        _set_nested_attr(net, name.split("."), new_weight)
    
    # the error occurs at the line below; everything is fine when a single GPU is used
    pre = net(data_g)
    loss = criteria(pre, target_g)

    T_grads = torch.autograd.grad(loss, T, allow_unused=True)

with torch.no_grad():
    T -= learning_rate * T_grads[0]
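
For reference, here is roughly what the two nested-attribute helpers from the linked thread look like (my paraphrase; see the thread for the original):

def _del_nested_attr(obj, names):
    # delete obj.<names[0]>.<names[1]>..., e.g. _del_nested_attr(net, ["module", "conv1", "weight"])
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        _del_nested_attr(getattr(obj, names[0]), names[1:])

def _set_nested_attr(obj, names, value):
    # set obj.<names[0]>.<names[1]>... = value, mirroring _del_nested_attr
    if len(names) == 1:
        setattr(obj, names[0], value)
    else:
        _set_nested_attr(getattr(obj, names[0]), names[1:], value)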

Error:
File "train_ours_meta_clth1m.py", line 207, in update_T
    pre = net(data_g)
File "/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
File "/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 152, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in parallel_apply
    raise output
File "/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 59, in _worker
    output = module(*input, **kwargs)
File "/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
File "/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torchvision/models/resnet.py", line 192, in forward
    x = self.conv1(x)
File "/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
File "/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 338, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)
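
In case it helps with diagnosing: my guess is that after _set_nested_attr the weights are plain Tensors rather than nn.Parameters, so DataParallel's replicate step no longer broadcasts them to the other GPUs, and the replica on device 1 still points at weights living on device 0. A minimal sketch of what I suspect is happening (hypothetical, just to illustrate; needs >= 2 GPUs):

import torch
import torch.nn as nn

model = nn.DataParallel(nn.Linear(4, 4).cuda())

# replace the registered nn.Parameter with a plain tensor (as _set_nested_attr does)
del model.module.weight
model.module.weight = torch.randn(4, 4, device="cuda:0", requires_grad=True)

x = torch.randn(8, 4, device="cuda:0")
out = model(x)  # plain tensor attributes are not broadcast, so the replica on cuda:1 fails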

Does anyone know how to make this work with nn.DataParallel? Any help would be much appreciated!