Device mismatch error with nn.DataParallel

I’m training a modified ResNet on multiple GPUs. Here is the residual block for the ResNet:

import torch.nn as nn
# `af` (which provides af.InternalClassifier) is a helper module from the repo

class BasicBlockWOutput(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, params, stride=1):
        super(BasicBlockWOutput, self).__init__()
        add_output = params[0]
        num_classes = params[1]
        input_size = params[2]
        self.output_id = params[3]

        self.depth = 2

        layers = nn.ModuleList()

        conv_layer = []
        conv_layer.append(nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False))
        conv_layer.append(nn.BatchNorm2d(channels))
        conv_layer.append(nn.ReLU())
        conv_layer.append(nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False))
        conv_layer.append(nn.BatchNorm2d(channels))

        layers.append(nn.Sequential(*conv_layer))

        shortcut = nn.Sequential()

        if stride != 1 or in_channels != self.expansion*channels:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion*channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*channels)
            )

        layers.append(shortcut)
        layers.append(nn.ReLU())

        self.layers = layers

        if add_output:
            self.output = af.InternalClassifier(input_size, self.expansion*channels, num_classes) 
            self.no_output = False

        else:
            self.output = None
            self.forward = self.only_forward
            self.no_output = True
            
    def forward(self, x):
        fwd = self.layers[0](x) # conv layers
        fwd = fwd + self.layers[1](x) # shortcut
        return self.layers[2](fwd), 1, self.output(fwd)         # output layers for this module
    
    def only_output(self, x):
        fwd = self.layers[0](x) # conv layers
        fwd = fwd + self.layers[1](x) # shortcut
        fwd = self.layers[2](fwd) # activation
        out = self.output(fwd)         # output layers for this module
        return out
    
    def only_forward(self, x):
        fwd = self.layers[0](x) # conv layers
        fwd = fwd + self.layers[1](x) # shortcut
        return self.layers[2](fwd), 0, None # activation

When I run it on multiple GPUs with nn.DataParallel, I get a device mismatch error:

RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/export/mlrg/sshekhar/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/export/mlrg/sshekhar/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/export/mlrg/sshekhar/XAI/Shallow-Deep-Networks-gpub/architectures/SDNs/ResNet_SDN.py", line 159, in forward
    fwd, is_output, output = layer(fwd)
  File "/export/mlrg/sshekhar/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/export/mlrg/sshekhar/XAI/Shallow-Deep-Networks-gpub/architectures/SDNs/ResNet_SDN.py", line 68, in only_forward
    fwd = self.layers[0](x) # conv layers
  File "/export/mlrg/sshekhar/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/export/mlrg/sshekhar/anaconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/export/mlrg/sshekhar/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/export/mlrg/sshekhar/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 419, in forward
    return self._conv_forward(input, self.weight)
  File "/export/mlrg/sshekhar/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 416, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

It seems like your input is on GPU 1, but your network's weights are on GPU 0. From the error trace, the issue stems from how only_forward is bound to self.forward in __init__: when DataParallel replicates the module onto each GPU, that bound method is copied along with the other instance attributes but still points at the original module on device 0, so the replicas end up running device-0 weights on device-1 inputs. This looks similar to the issue described in https://github.com/pytorch/pytorch/issues/8637 - check out the suggestion by Ssnl about why duplicating non-tensor objects in DataParallel can lead to these device mismatch errors.
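Here is a minimal sketch of that failure mode (the toy module, layer sizes, and shapes are made up for illustration, and it assumes at least two visible GPUs):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        # Storing a bound method in the instance __dict__ captures a reference
        # to *this* object. DataParallel shallow-copies that __dict__ into each
        # replica, so every replica's forward still calls the original module,
        # whose weights live on device 0.
        self.forward = self.only_forward

    def only_forward(self, x):
        return self.conv(x)

model = nn.DataParallel(Toy().cuda())   # replicas end up on cuda:0, cuda:1, ...
x = torch.randn(8, 3, 32, 32).cuda()
out = model(x)                          # raises the same device-mismatch RuntimeError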

So I was able to resolve my issue based on Omkar's suggestion to look at https://github.com/pytorch/pytorch/issues/8637.

I changed the following: instead of binding self.forward to the only_forward method in __init__, I call it directly inside forward:

        if add_output:
            self.output = af.InternalClassifier(input_size, self.expansion*channels, num_classes) 
            self.no_output = False
        else:
            self.output = None
            #self.forward = self.only_forward
            self.no_output = True
            
    def forward(self, x):

        if self.no_output:
            return self.only_forward(x)        
        else:
            fwd = self.layers[0](x) # conv layers
            fwd = fwd + self.layers[1](x) # shortcut
            return self.layers[2](fwd), 1, self.output(fwd)         # output layers for this module

    
    def only_output(self, x):
        fwd = self.layers[0](x) # conv layers
        fwd = fwd + self.layers[1](x) # shortcut
        fwd = self.layers[2](fwd) # activation
        out = self.output(fwd)         # output layers for this module
        return out
    
    def only_forward(self, x):
        fwd = self.layers[0](x) # conv layers
        fwd = fwd + self.layers[1](x) # shortcut
        return self.layers[2](fwd), 0, None # activation
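A quick sanity check after this change (a sketch with placeholder constructor arguments; add_output is False here so af.InternalClassifier isn't needed, and it assumes at least two visible GPUs). torch.nn.parallel.replicate is the same primitive DataParallel uses to copy the module onto each device:

import torch
from torch.nn.parallel import replicate

block = BasicBlockWOutput(16, 16, params=[False, 10, 32, 0]).cuda()

# forward is no longer stored in the instance __dict__, so each replica
# resolves forward() on itself instead of on the original module:
print('forward' in block.__dict__)            # False (it was True before the fix)

# and each replica carries its own copy of the weights on its own device:
for i, rep in enumerate(replicate(block, [0, 1])):
    print(i, rep.layers[0][0].weight.device)  # cuda:0, cuda:1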