PyTorch multi-GPU problem with nn.DataParallel

Hi everyone,

I’m trying to use nn.DataParallel for multi-GPU training, but I get the following error. I’ve looked through various threads, but I wasn’t able to fix the issue:

RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

To limit the visible GPUs, I set CUDA_VISIBLE_DEVICES before importing torch (as I also do when training on a single GPU):

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 2, 3"
import torch  # imported only after the variable is set
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
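CUDA_VISIBLE_DEVICES is read when CUDA is initialized, so it has to be set before the first CUDA call; setting it before importing torch, as above, is the safe option. The GPUs it lists are then renumbered from 0 inside the process, in the order they are listed. The small pure-Python helper below is a hypothetical illustration of that renumbering (it only handles the plain comma-separated integer form and does not query any GPU):

```python
from typing import Dict


def visible_device_mapping(value: str) -> Dict[int, int]:
    """Map physical GPU ids listed in a CUDA_VISIBLE_DEVICES value to the
    logical ids a CUDA process sees (hypothetical helper, simplified).

    Only handles integer lists such as "0,2,3"; GPU UUIDs and CUDA's own
    handling of invalid entries are not modelled here.
    """
    physical_ids = [int(tok) for tok in value.split(",") if tok.strip()]
    # Visible devices are renumbered 0..N-1 in the order they are listed.
    return {phys: logical for logical, phys in enumerate(physical_ids)}


mapping = visible_device_mapping("0,2,3")
print(mapping)                   # {0: 0, 2: 1, 3: 2}
print(sorted(mapping.values()))  # [0, 1, 2] -> the ids to use inside the script
```

So physical GPU 2 is addressed as cuda:1 inside the script, and physical GPU 3 as cuda:2.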

The model is used as follows (with weights from a previous training run):

model = ResNet()
model.load_state_dict(torch.load('rot_weights.pt'))
model = nn.DataParallel(model, device_ids=[0, 2, 3]).to(device)
optimizer = optim.Adam(model.parameters(),
                        lr=learning_rate,
                        betas=(0.9, 0.999),
                        eps=1e-08,
                        weight_decay=0,
                        amsgrad=False)

The model is a ResNet-18, of which I only want the feature-extraction part, modified for my regression problem:

import torch
import torch.nn as nn
import pretrainedmodels

class ResNet(nn.Module):
    def __init__(self):
        super(ResNet, self).__init__()
        self.model = pretrainedmodels.__dict__['resnet18'](pretrained='imagenet')
        self.regression_layer = nn.Sequential(nn.Linear(512, 6))

    def forward(self, x):
        batch_size, _, _, _ = x.shape  # take the batch size from the input images
        x = self.model.features(x)
        # global average pooling, then flatten to (batch_size, 512)
        x = torch.nn.functional.adaptive_avg_pool2d(x, 1).reshape(batch_size, -1)
        x = self.regression_layer(x)
        # helper defined elsewhere in my code (not shown): 6D output -> rotation matrix
        x = compute_rotation_matrix_from_ortho6d(x.view(batch_size, -1))

        return x

    def compute_rotation_matrix_l2_loss(self, gt_rotation_matrix, predict_rotation_matrix):
        loss_function = nn.MSELoss()
        loss = loss_function(predict_rotation_matrix, gt_rotation_matrix)

        return loss

    def compute_rotation_matrix_geodesic_loss(self, gt_rotation_matrix, predict_rotation_matrix):
        theta = compute_geodesic_distance_from_two_matrices(gt_rotation_matrix, predict_rotation_matrix)
        error = theta.mean()

        return error
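For context on the geodesic loss: a common definition of the geodesic distance between two rotation matrices R1 and R2 is the angle of their relative rotation, theta = arccos((trace(R1^T R2) - 1) / 2). The body of compute_geodesic_distance_from_two_matrices isn't shown, so the sketch below is only a hypothetical stand-in (geodesic_distance is a made-up name), assuming batched tensors of shape (B, 3, 3):

```python
import math

import torch


def geodesic_distance(r1: torch.Tensor, r2: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Angle in radians of the relative rotation r1^T r2, per batch element.

    The cosine is clamped slightly inside [-1, 1] so acos and its gradient
    stay finite (at the cost of a tiny bias near 0 and pi).
    """
    relative = torch.bmm(r1.transpose(1, 2), r2)           # (B, 3, 3)
    trace = relative.diagonal(dim1=1, dim2=2).sum(dim=-1)  # (B,)
    cos = ((trace - 1.0) / 2.0).clamp(-1.0 + eps, 1.0 - eps)
    return torch.acos(cos)                                 # (B,)


identity = torch.eye(3).unsqueeze(0)
rot_z_90 = torch.tensor([[[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]])
# A 90-degree rotation about z should give approximately pi / 2.
print(geodesic_distance(identity, rot_z_90).item(), math.pi / 2)
```

Taking .mean() over the batch, as compute_rotation_matrix_geodesic_loss does, turns these per-sample angles into a scalar loss.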

Any suggestion would be really appreciated!

If you mask the GPUs via CUDA_VISIBLE_DEVICES, the visible devices are renumbered inside the script as 0, 1, ..., nb_gpus - 1, so you should use device_ids=[0, 1, 2] instead of [0, 2, 3].
Could you change it and check if that solves the issue?
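A minimal sketch of that setup, assuming CUDA_VISIBLE_DEVICES="0,2,3" was set before torch was imported, so the process sees three GPUs with logical ids 0, 1 and 2. nn.DataParallel expects the wrapped module's parameters to live on device_ids[0] before the first forward call. wrap_for_data_parallel is a hypothetical helper name; without a GPU it returns the model unchanged so the snippet still runs:

```python
import torch
import torch.nn as nn


def wrap_for_data_parallel(model: nn.Module) -> nn.Module:
    """Move `model` to the first visible GPU and wrap it in nn.DataParallel.

    Hypothetical helper: device_ids are the logical ids seen inside the
    process (0..N-1 after CUDA_VISIBLE_DEVICES masking), not physical ids.
    """
    if not torch.cuda.is_available():
        return model  # CPU fallback: nothing to parallelize
    n_gpus = torch.cuda.device_count()
    model = model.to("cuda:0")  # parameters must live on device_ids[0]
    if n_gpus > 1:
        model = nn.DataParallel(model, device_ids=list(range(n_gpus)))
    return model


model = wrap_for_data_parallel(nn.Linear(8, 2))
device = next(model.parameters()).device
x = torch.randn(4, 8, device=device)  # DataParallel scatters the batch across GPUs
print(model(x).shape)  # torch.Size([4, 2])
```

Deriving device_ids from torch.cuda.device_count() avoids hard-coding physical ids that no longer exist inside the masked process.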

Hey @ptrblck, thanks for the prompt reply. Unfortunately, I still get the same error:

RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

Full error below:

  File "/home/chiara/my_workspace/RegressionCNN_rot/main.py", line 162, in main
    train_loss_epoch, train_error_epoch = training(model, train_loader)
  File "/home/chiara/my_workspace/RegressionCNN_rot/main.py", line 94, in training
    out_rot_mat = model(image_batch)
  File "/home/chiara/anaconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chiara/anaconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 167, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/chiara/anaconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 177, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/chiara/anaconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/chiara/anaconda3/envs/python3.7/lib/python3.7/site-packages/torch/_utils.py", line 429, in reraise
    raise self.exc_type(msg)
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/chiara/anaconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/chiara/anaconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chiara/my_workspace/RegressionCNN_rot/model.py", line 19, in forward
    x = self.model.features(x)
  File "/home/chiara/anaconda3/envs/python3.7/lib/python3.7/site-packages/pretrainedmodels/models/torchvision_models.py", line 322, in features
    x = self.conv1(input)
  File "/home/chiara/anaconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/chiara/anaconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 399, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/chiara/anaconda3/envs/python3.7/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 396, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

I have also tried without specifying device_ids, but nothing changed.

Thanks for the update. I’ve rechecked your initial code and remembered seeing a similar issue with pretrainedmodels before, so I guess you might be hitting the same problem.
It seems the repository is breaking nn.DataParallel, so you could either use another repo (e.g. torchvision.models) or use DistributedDataParallel instead (I haven’t verified that it’s working with pretrainedmodels, but it might).
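The traceback fits this: replica 1 runs on device 1, yet the self.conv1 used inside pretrainedmodels' features method (torchvision_models.py) still holds weights on device 0. A plausible cause, not checked against every pretrainedmodels version, is that features is attached to the original model instance as a bound method (e.g. via types.MethodType) rather than defined on the class. As far as I can tell from torch/nn/parallel/replicate.py, nn.DataParallel creates replicas by shallow-copying each module's __dict__, so a copied bound method still points at the original module. The pure-Python sketch below reproduces that mechanism without any GPU; FakeModule, FixedModule and replicate are made-up stand-ins:

```python
import types


class FakeModule:
    """Stand-in for nn.Module: `device` plays the role of the weights' device."""

    def __init__(self, device):
        self.device = device


def features(self, x):
    # Inside a bound method, `self` is whatever object it was bound to.
    return f"input on {x}, weights on {self.device}"


def replicate(module, device):
    """Mimic DataParallel replication: new object with a shallow copy of __dict__."""
    replica = module.__class__.__new__(module.__class__)
    replica.__dict__ = module.__dict__.copy()
    replica.device = device  # the replica's own weights live on the new device
    return replica


# Instance-level binding, the pattern suspected in pretrainedmodels:
original = FakeModule("cuda:0")
original.features = types.MethodType(features, original)
replica = replicate(original, "cuda:1")
print(replica.features("cuda:1"))  # input on cuda:1, weights on cuda:0 -> mismatch


# A method defined on the class is looked up on the replica itself instead.
class FixedModule(FakeModule):
    def features(self, x):
        return f"input on {x}, weights on {self.device}"


fixed = replicate(FixedModule("cuda:0"), "cuda:1")
print(fixed.features("cuda:1"))  # input on cuda:1, weights on cuda:1
```

This is why defining the feature extractor as regular submodules and a class-level forward, as torchvision does, sidesteps the problem.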

Thanks for the suggestions @ptrblck. I’m trying to use torchvision.models, but I think I need to modify my model class.

When I replace the line

self.model = pretrainedmodels.__dict__['resnet18'](pretrained='imagenet')

with

self.model = models.resnet18(pretrained=True)

I get the following error:

AttributeError: 'ResNet' object has no attribute 'features'

Yes, you’ll need some modifications, since your forward method depends on the .features attribute, which torchvision’s ResNet doesn’t have.
The torchvision implementation (torchvision/models/resnet.py) calls the layers (or blocks) directly in its forward pass instead of using a features/classifier split.

You could create a custom model that reuses the layers of torchvision.models.resnet18 and defines its own forward method.
Here is an example of how to do it:

import torch
import torch.nn as nn
from torchvision import models


class MyResNet18(nn.Module):
    def __init__(self, resnet):
        super().__init__()
        # create features branch using https://github.com/pytorch/vision/blob/2a52c2dca73513d0d0c3e2a505aed05e5b9aa792/torchvision/models/resnet.py#L230-L246
        self.features = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4
        )
        self.avgpool = resnet.avgpool
        self.fc = resnet.fc
        
    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        # same computation as torchvision's ResNet._forward_impl, routed through self.features
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_impl(x)
    

# create standard model and reuse in custom one
model = models.resnet18()
print(model)
custom_model = MyResNet18(model)        

# check outputs
x = torch.randn(2, 3, 224, 224)
out = model(x)
custom_out = custom_model(x)

# compare outputs to make sure the model works as intended
print((out - custom_out).abs().max())
> tensor(0., grad_fn=<MaxBackward1>)

print(custom_model.features)

Hi @ptrblck, it seems that following your suggestion solved the error. Thanks a lot!