RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

Hi All,
I am using torch.nn.DataParallel for the model in a multi-GPU setup (4 GPUs), as defined below, and find that there is some device mismatch between the input and the weight. I am using the following module for a decoder-style network.

from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn

class DepthDecoder(nn.Module):
    def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, max_depth=10, use_skips=True, use_bn=False):
        super(DepthDecoder, self).__init__()

        self.num_output_channels = num_output_channels
        self.max_depth = max_depth
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = scales

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # decoder
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            # upconv_0
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            if use_bn:
                self.convs[("upconv", i, 0)] = ConvBlock_bn(num_ch_in, num_ch_out)
            else:
                self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            if use_bn:
                self.convs[("upconv", i, 1)] = ConvBlock_bn(num_ch_in, num_ch_out)
            else:
                self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            # self.convs[("dispconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
            self.convs[("depthconv", s)] = Conv3x3(self.num_ch_dec[s], self.num_output_channels)

        self.decoder = nn.ModuleList(list(self.convs.values()))
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_features):
        self.outputs = {}

        # decoder
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            x = [upsample(x)]
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)
            if i in self.scales:
                # self.outputs[("disp", i)] = self.sigmoid(self.convs[("dispconv", i)](x))
                self.outputs[("depth", i)] = self.sigmoid(self.convs[("depthconv", i)](x)) * self.max_depth

        return self.outputs
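
For context, the decoder is wrapped roughly like this (a minimal sketch; the device ids and the .cuda() call are assumptions, not my exact training code):

decoder = DepthDecoder(num_ch_enc)
decoder = torch.nn.DataParallel(decoder, device_ids=[0, 1, 2, 3]).cuda()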

ConvBlock and Conv3x3 are just simple convolutional classes performing 2D convolutions.

class ConvBlock(nn.Module):
    """Layer to perform a convolution followed by ELU
    """
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.nonlin(out)
        return out

I have seen similar posts with the same error, and what I gathered is that it has something to do with redefining things in the __init__ function, and with putting as much functionality as possible into the forward function. But I was not quite able to pin-point what exactly was causing this issue in my case.

Any suggestion would be really helpful!


Could you post an executable code snippet? The current code doesn't define num_ch_enc (shape and values), the upsample usage, or the input shape.

Hi @ptrblck, please find my responses inline:

  1. define the num_ch_enc (shape and values)
    -------> Here, num_ch_enc defines the number of encoder channels; it is a <class 'numpy.ndarray'> of shape (5,) containing the values [64 64 128 256 512].
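
In code (reconstructed from the values above; the variable name matches the constructor argument):

import numpy as np

num_ch_enc = np.array([64, 64, 128, 256, 512])  # shape (5,)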

  2. Input shape to forward function
    -------> The input to the forward function is a list of length 5, where each element is a tensor of size:

torch.Size([8, 64, 240, 320])
torch.Size([8, 64, 120, 160])
torch.Size([8, 128, 60, 80])
torch.Size([8, 256, 30, 40])
torch.Size([8, 512, 15, 20])

Here, 8 is the batch size (which rightly reduces to 2 per GPU when I use 4 GPUs with DataParallel).
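
For a self-contained repro, the input list can be faked with random tensors of the shapes above (a sketch, not the real encoder output):

import torch

shapes = [(8, 64, 240, 320), (8, 64, 120, 160), (8, 128, 60, 80),
          (8, 256, 30, 40), (8, 512, 15, 20)]
input_features = [torch.randn(*s) for s in shapes]  # stand-in for encoder features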

  3. Definition of upsample and Conv3x3
import torch.nn.functional as F

def upsample(x):
    """Upsample input tensor by a factor of 2
    """
    return F.interpolate(x, scale_factor=2, mode="nearest")

class Conv3x3(nn.Module):
    """Layer to pad and convolve input
    """
    def __init__(self, in_channels, out_channels, use_refl=True):
        super(Conv3x3, self).__init__()

        if use_refl:
            self.pad = nn.ReflectionPad2d(1)
        else:
            self.pad = nn.ZeroPad2d(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        out = self.pad(x)
        out = self.conv(out)
        return out
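
For reference, a quick shape check of both helpers (reflection-padding by 1 followed by a 3x3 convolution keeps the spatial size, while upsample doubles it):

import torch

x = torch.randn(1, 64, 8, 8)
print(upsample(x).shape)         # torch.Size([1, 64, 16, 16])
print(Conv3x3(64, 16)(x).shape)  # torch.Size([1, 16, 8, 8])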

The DataParallel part goes without any hassle for the encoder, but it fails and gives this error when it goes through the DepthDecoder network. (The output of the encoder goes to the DepthDecoder class specified above.)

Thanks for the update.
The issue is raised because self.convs is an OrderedDict instead of an nn.ModuleDict, which will not properly register these modules (so the replicas created by nn.DataParallel still look up the convolutions stored on the default device).
Change it to the latter and adapt the indexing, as nn.ModuleDict expects string keys (e.g. "upconv{}{}".format(i, 0)).
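
A minimal sketch of that change (TinyDecoder and the key format are illustrative, not your exact model):

import torch.nn as nn

class TinyDecoder(nn.Module):
    def __init__(self):
        super(TinyDecoder, self).__init__()
        # nn.ModuleDict registers its values as submodules, so they are moved
        # and replicated together with the parent module
        self.convs = nn.ModuleDict()
        for i in range(4, -1, -1):
            # string keys instead of ("upconv", i, 0) tuples
            self.convs["upconv{}{}".format(i, 0)] = nn.Conv2d(16, 16, 3, padding=1)

    def forward(self, x):
        for i in range(4, -1, -1):
            x = self.convs["upconv{}{}".format(i, 0)](x)
        return x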
Also, note that nn.DataParallel will split each element in your list in dim0, so that each GPU will receive a list of 5 tensors in the shape:

torch.Size([1, 64, 240, 320])
torch.Size([1, 64, 120, 160])
torch.Size([1, 128, 60, 80])
torch.Size([1, 256, 30, 40])
torch.Size([1, 512, 15, 20])

when using 8 GPUs.
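
A small sketch of that splitting behavior (ShapeProbe is a hypothetical module used only to print what each replica receives; it assumes more than one visible GPU):

import torch
import torch.nn as nn

class ShapeProbe(nn.Module):
    def forward(self, feats):
        # each replica prints the chunk of the list it received
        print([tuple(f.shape) for f in feats])
        return feats[-1]

if torch.cuda.device_count() > 1:
    probe = nn.DataParallel(ShapeProbe()).cuda()
    feats = [torch.randn(8, c, 4, 4, device="cuda") for c in (64, 64, 128, 256, 512)]
    out = probe(feats)  # the batch dimension (8) is chunked across the visible GPUs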


Thank you @ptrblck, that worked like a charm!!