Hello! I would like to compute the gradients of a tensor matrix with respect to another set of gradients. Below code works fine on a single GPU. However, when multiple GPUs are used, the error occurs. Can anyone share with me how to handle this? Thanks!

May I ask if you could have a look? Thank you very much! @ptrblck @albanD

My code:

```
#model
model = models.resnet50(pretrained=True)
model.fc = nn.Linear(2048, num_classes)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
model = model.to(device)
#T is a tensor, which I would like to compute its gradients based on another set of gradients
T = torch.from_numpy(numpy.random.rand(num_classes, num_classes)).float()
T = T.requires_grad_(True).to(device)
#do not want to update the weights of the original model, so get another copy of the model
net = copy.deepcopy(model)
#copy net's weights for later gradients computation manually
original_weights = OrderedDict()
for param, weights in net.named_parameters():
if not weights.requires_grad:
print(param)
else:
original_weights[param] = copy.deepcopy(weights)
original_weights_keys = tuple(original_weights.keys())
#data_s is images, target_s is labels
with torch.enable_grad():
logits = net(data_s)
#pre1 is the (target_s)-th row of the tensor matrix T, basically, loss_ is crossentropy loss
pre1 = T[torch.cuda.LongTensor(target_s.data)]
pre2 = torch.mul(F.softmax(logits, dim=1), pre1)
loss_ = -(torch.log(pre2.sum(1))).sum(0)/float(len(target_s))
#manually compute gradients
net_grads = torch.autograd.grad(loss_, net.parameters(), retain_graph=True, create_graph=True, allow_unused=True)
#see _del_nested_attr() and _set_nested_attr() in https://discuss.pytorch.org/t/how-to-calculate-2nd-derivative-of-a-likelihood-function/15085/30
for param, grad in zip(original_weights_keys, net_grads):
if grad is None:
continue
else:
new_weight = original_weights[param] - learning_rate * grad
# remove old Parameter
_del_nested_attr(net, param.split("."))
# set the Tensor with history
_set_nested_attr(net, param.split("."), new_weight)
#error occurs from below line, but it is fine when a single GPU is used.
pre = net(data_g)
loss = criteria(pre, target_g)
T_grads = torch.autograd.grad(loss, T, allow_unused=True)
with torch.set_grad_enabled(False):
T -= learning_rate * T_grads[0]
```

Error:

File “train_ours_meta_clth1m.py”, line 207, in update_T

pre = net(data_g)

File “/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py”, line 493, in **call**

result = self.forward(*input, **kwargs)

File “/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py”, line 152, in forward

outputs = self.parallel_apply(replicas, inputs, kwargs)

File “/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py”, line 162, in parallel_apply

return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

File “/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py”, line 83, in parallel_apply

raise output

File “/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py”, line 59, in _worker

output = module(*input, **kwargs)

File “/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py”, line 493, in **call**

result = self.forward(*input, **kwargs)

File “/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torchvision/models/resnet.py”, line 192, in forward

x = self.conv1(x)

File “/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py”, line 493, in **call**

result = self.forward(*input, **kwargs)

File “/root/miniconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/conv.py”, line 338, in forward

self.padding, self.dilation, self.groups)

RuntimeError: Expected tensor for argument #1 ‘input’ to have the same device as tensor for argument #2 ‘weight’; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)