DataParallel does not support a custom layer

Hi, I have a problem with multi-GPU training. The code is from https://github.com/chaoqichen/HTCN. Here is the key code:

class RandomLayer(nn.Module):

    def __init__(self, input_dim_list=[], output_dim=1024):
        super(RandomLayer, self).__init__()
        self.input_num = len(input_dim_list)# 2
        self.output_dim = output_dim
        self.random_matrix = [torch.rand(input_dim_list[i], output_dim) for i in range(self.input_num)]

    def forward(self, input_list):
        return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)]
        return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list))
        for single in return_list[1:]:
            return_tensor = torch.mul(return_tensor, single)
        return return_tensor

    def cuda(self):
        super(RandomLayer, self).cuda()
        self.random_matrix = [val.cuda() for val in self.random_matrix]

class vgg16(_fasterRCNN):

  def __init__(self, classes, pretrained=False, class_agnostic=False,lc=False,gc=False, la_attention = False, mid_attention = False):
    self.model_path = cfg.VGG_PATH
    self.dout_base_model = 512
    self.pretrained = pretrained
    self.class_agnostic = class_agnostic
    self.lc = lc
    self.gc = gc

    _fasterRCNN.__init__(self, classes, class_agnostic,lc,gc, la_attention, mid_attention)

  def _init_modules(self):
    vgg = models.vgg16()
    if self.pretrained:
        print("Loading pretrained weights from %s" %(self.model_path))
        state_dict = torch.load(self.model_path)
        vgg.load_state_dict({k:v for k,v in state_dict.items() if k in vgg.state_dict()})

    vgg.classifier = nn.Sequential(*list(vgg.classifier._modules.values())[:-1])

    # not using the last maxpool layer
    #print(vgg.features)
    self.RCNN_base1 = nn.Sequential(*list(vgg.features._modules.values())[:14])
    self.RCNN_base2 = nn.Sequential(*list(vgg.features._modules.values())[14:21])
    self.RCNN_base3 = nn.Sequential(*list(vgg.features._modules.values())[21:-1])
    #print(self.RCNN_base1)
    #print(self.RCNN_base2)
    self.netD = netD(context=self.gc)
    self.netD_pixel = netD_pixel(context=self.lc)
    self.netD_mid = netD_mid(context=self.gc)
    feat_d = 4096
    feat_d2 = 384
    feat_d3 = 2048

    self.RandomLayer = RandomLayer([feat_d, feat_d2], feat_d3)
    self.RandomLayer.cuda()
    # Fix the layers before conv3:
    self.netD_da = netD_da(feat_d3)

    for layer in range(10):
      for p in self.RCNN_base1[layer].parameters(): p.requires_grad = False

    # self.RCNN_base = _RCNN_base(vgg.features, self.classes, self.dout_base_model)

    self.RCNN_top = vgg.classifier

    self.RCNN_cls_score = nn.Linear(feat_d+feat_d2, self.n_classes)
    if self.class_agnostic:
        self.RCNN_bbox_pred = nn.Linear(feat_d+feat_d2, 4)
    else:
        self.RCNN_bbox_pred = nn.Linear(feat_d+feat_d2, 4 * self.n_classes)


  def _head_to_tail(self, pool5):
    
    pool5_flat = pool5.view(pool5.size(0), -1)
    fc7 = self.RCNN_top(pool5_flat)

    return fc7

Calling self.RandomLayer.cuda() does not place RandomLayer on multiple devices. The error message is as follows:

return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)]

RuntimeError: Expected tensor for ‘out’ to have the same device as tensor for argument #3 ‘mat2’; but device 1 does not equal 0 (while checking arguments for addmm)

I also printed the devices of input_list[i] and self.random_matrix[i] when using two GPUs (self.input_num == 2); the result is as follows:

gpu id 0:
input_list0: cuda:0 self.random_matrix0: cuda:0
input_list1: cuda:0 self.random_matrix1: cuda:0
gpu id 1:
input_list0: cuda:1 self.random_matrix0: cuda:0
input_list1: cuda:1 self.random_matrix1: cuda:0

Is there any advice on how to fix this problem?

I changed the RandomLayer code and it now works with 2 GPUs, but it still fails with 3 or 4 GPUs. I don't know why, so could you please have a look at my question? In https://github.com/pytorch/pytorch/issues/50087, @ngimel told me that overriding the .cuda() method this way is not supported and that I need to override the _apply method instead, but I don't know how to implement it.
@ptrblck
Here is my modified RandomLayer code. Because len(self.random_matrix) is actually 2 in my project, I simply split self.random_matrix into two separate parameters, random_matrix0 and random_matrix1. It works with 2 GPUs, but fails with 3 or more.

class RandomLayer(nn.Module):
    def __init__(self, batch_size, input_dim_list=[], output_dim=1024):
        super(RandomLayer, self).__init__()
        self.input_num = len(input_dim_list)# 2
        self.output_dim = output_dim
        self.random_matrix0 = torch.nn.Parameter(torch.rand(input_dim_list[0], output_dim))
        self.random_matrix1 = torch.nn.Parameter(torch.rand(input_dim_list[1], output_dim))

    def forward(self, input_list):
        assert(len(input_list) == 2)
        return_list0 = torch.mm(input_list[0], self.random_matrix0)
        return_list1 = torch.mm(input_list[1], self.random_matrix1)
        return_tensor = return_list0 / math.pow(float(self.output_dim), 1.0 / 2)
        return_tensor = torch.mul(return_tensor, return_list1)
        return return_tensor
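For reference, here is a minimal sketch of the _apply override that @ngimel suggested, assuming a PyTorch version in which .cuda(), .to(), .double() and similar methods all go through nn.Module._apply(fn) (extra arguments such as recurse are simply forwarded). The class name RandomLayerApply is hypothetical, and .double() stands in for .cuda() so the sketch runs on CPU. Caveat: as far as I can tell this only fixes .cuda()/.to(). nn.DataParallel builds its replicas by copying the registered parameters and buffers to each GPU and shallow-copying the remaining attributes, so a plain list would still point at the tensors on the original device in every replica, which matches the device printout above.

```python
import math

import torch
import torch.nn as nn


class RandomLayerApply(nn.Module):
    """Hypothetical variant that keeps the plain list but overrides _apply."""

    def __init__(self, input_dim_list, output_dim=1024):
        super().__init__()
        self.input_num = len(input_dim_list)
        self.output_dim = output_dim
        # Plain Python list: NOT registered as parameters or buffers.
        self.random_matrix = [torch.rand(d, output_dim) for d in input_dim_list]

    def _apply(self, fn, *args, **kwargs):
        # Let nn.Module convert registered parameters/buffers/submodules first.
        super()._apply(fn, *args, **kwargs)
        # Then apply the same conversion (e.g. t.cuda(), t.double()) to the list.
        self.random_matrix = [fn(t) for t in self.random_matrix]
        return self

    def forward(self, input_list):
        return_list = [torch.mm(x, m) for x, m in zip(input_list, self.random_matrix)]
        return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list))
        for single in return_list[1:]:
            return_tensor = torch.mul(return_tensor, single)
        return return_tensor


layer = RandomLayerApply([4, 3], 5).double()
print(layer.random_matrix[0].dtype)  # torch.float64
```

Registering the tensors as parameters or buffers avoids the need for this override altogether.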

I'm not sure why you want to override the cuda method, but I guess you ran into a device mismatch with self.random_matrix, since it is initialized as a plain Python list, and tried to fix it with the custom cuda method.

To write device-agnostic code, remove all cuda() calls from the module and make sure all modules, parameters, and buffers are properly registered.
For example, assigning a plain list to self.random_matrix won't register the tensors, so you would have to either use nn.ParameterList or register each tensor via self.register_buffer.
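The difference between a plain list and registered buffers can be seen without a GPU. In this sketch the class names ListHolder and BufferHolder are hypothetical, and .double() is used as a stand-in for .cuda(), since both go through the same conversion of registered tensors. Buffers are moved and converted with the module and saved in its state_dict, but, not being parameters, they are not updated by the optimizer, which matches the original plain-list version where the random matrices were never trained.

```python
import torch
import torch.nn as nn


class ListHolder(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: nn.Module does not see these tensors.
        self.random_matrix = [torch.rand(4, 2), torch.rand(3, 2)]


class BufferHolder(nn.Module):
    def __init__(self):
        super().__init__()
        # Each tensor registered as a buffer: converted by .to()/.cuda()/.double()
        # and included in state_dict, but not returned by .parameters().
        for i, t in enumerate([torch.rand(4, 2), torch.rand(3, 2)]):
            self.register_buffer(f"random_matrix{i}", t)


a = ListHolder().double()
b = BufferHolder().double()
print(list(a.state_dict().keys()))  # []
print(list(b.state_dict().keys()))  # ['random_matrix0', 'random_matrix1']
print(a.random_matrix[0].dtype)     # torch.float32 (left unconverted)
print(b.random_matrix0.dtype)       # torch.float64
```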

This makes sure that calling random_layer.cuda() pushes all internal submodules, parameters, and buffers to the specified device, and that nn.DataParallel can replicate them onto each GPU. Your updated module looks generally alright and works fine in this code snippet on 8 GPUs:

class RandomLayer(nn.Module):
    def __init__(self, batch_size, input_dim_list=[], output_dim=1024):
        super(RandomLayer, self).__init__()
        self.input_num = len(input_dim_list)# 2
        self.output_dim = output_dim
        self.random_matrix0 = torch.nn.Parameter(torch.rand(input_dim_list[0], output_dim))
        self.random_matrix1 = torch.nn.Parameter(torch.rand(input_dim_list[1], output_dim))

    def forward(self, x1, x2):
        return_list0 = torch.mm(x1, self.random_matrix0)
        return_list1 = torch.mm(x2, self.random_matrix1)
        return_tensor = return_list0 / math.pow(float(self.output_dim), 1.0 / 2)
        return_tensor = torch.mul(return_tensor, return_list1)
        return return_tensor


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.rand = RandomLayer(1, [1, 1])

    def forward(self, x1, x2):
        print(x1.device)
        x = self.rand(x1, x2)
        return x

model = MyModule()
model = nn.DataParallel(model).cuda()
x1 = torch.randn(8, 1).cuda()
x2 = torch.randn(8, 1).cuda()
out = model(x1, x2)
print(out.device)

Note that nn.DataParallel splits the input batch in dim0 and sends each chunk to the corresponding GPU. In your current code snippet you are indexing input_list in dim0, which also seems to be wrong, as it would index it in the (split) batch dimension.
You could thus either pass multiple inputs (as done in my code snippet) or stack the inputs and index them in another dimension.
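A small CPU sketch of the "stack and index in another dimension" option. Here torch.chunk along dim0 is only a rough stand-in for how nn.DataParallel scatters a tensor across 2 GPUs, and the example assumes both inputs have the same feature size, which torch.stack requires. With different sizes (such as feat_d and feat_d2), passing them as separate arguments is the simpler choice.

```python
import torch

x1 = torch.randn(8, 4)
x2 = torch.randn(8, 4)

# Stacking along dim0 puts the *inputs* in the batch dimension: shape [2, 8, 4].
# Splitting that across 2 devices gives each replica only one of the two inputs.
stacked_dim0 = torch.stack([x1, x2], dim=0)
print([c.shape for c in stacked_dim0.chunk(2, dim=0)])  # two chunks of [1, 8, 4]

# Stacking along dim1 keeps the batch in dim0: shape [8, 2, 4].
stacked_dim1 = torch.stack([x1, x2], dim=1)
for chunk in stacked_dim1.chunk(2, dim=0):  # two chunks of [4, 2, 4]
    a, b = chunk[:, 0], chunk[:, 1]         # recover the two inputs from dim1
    print(a.shape, b.shape)                 # torch.Size([4, 4]) torch.Size([4, 4])
```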

Hi ptrblck, thank you for your kind reply. I tried to use nn.ParameterList in multi-GPU mode in your code snippet, but it still doesn't work. Here is the source code (dp.py); it works on a single GPU.

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import math

feat_d = 16
feat_d2 = 8
feat_d3 = 8

class RandomLayer(nn.Module):
    def __init__(self, input_dim_list=[], output_dim=1024):
        super(RandomLayer, self).__init__()
        self.input_num = len(input_dim_list)# 2
        self.output_dim = output_dim
        self.random_matrix = torch.nn.ParameterList([torch.nn.Parameter(torch.rand(input_dim_list[i], output_dim)) for i in range(self.input_num)])
        print('*****')
        print(len(self.random_matrix))
        print(self.random_matrix)

    def forward(self, input_list):
        print('&&&&&')
        print(len(self.random_matrix))
        print(self.random_matrix)
        return_list = [torch.mm(input_list[i], self.random_matrix[i]) for i in range(self.input_num)]
        return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / 2)
        for single in return_list[1:]:
            return_tensor = torch.mul(return_tensor, single)
        return return_tensor


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.rand = RandomLayer([feat_d, feat_d2], feat_d3)

    def forward(self, x1, x2):
        print(x1.device)
        x = self.rand([x1, x2])
        return x

model = MyModule()
model = nn.DataParallel(model).cuda()
x1 = torch.randn(4, feat_d).cuda()
x2 = torch.randn(4, feat_d2).cuda()
out = model(x1, x2)
print(out.device)

When I run this code with the command CUDA_VISIBLE_DEVICES=4,5 python dp.py, len(self.random_matrix) is 0. I don't know why; could you have a look at this?

len(self.random_matrix) is 0 since you are not passing input_dim_list and are using the default value, which is an empty list.
This also means that self.input_num = len(input_dim_list) will be 0 and thus:

self.random_matrix = torch.nn.ParameterList([torch.nn.Parameter(torch.rand(input_dim_list[i], output_dim)) for i in range(self.input_num)])

will be empty.

Hi, ptrblck.
I think I did pass a list to input_dim_list in that code:
self.rand = RandomLayer([feat_d, feat_d2], feat_d3). This code works on a single GPU and correctly gives len(self.random_matrix) == 2, but with 2 GPUs it gives a wrong result, with len(self.random_matrix) == 0.
I also found that in 2-GPU mode len(self.random_matrix) equals 2 in the RandomLayer __init__ method, but equals 0 in the RandomLayer forward method.

You are right and I missed the constructor inside MyModule.
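You should get a warning when using nn.ParameterList with nn.DataParallel, since this combination is not supported, so you would need to fall back to registering the parameters separately (e.g. as random_matrix0 and random_matrix1 in your modified RandomLayer). The warning reads: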

UserWarning: nn.ParameterList is being used with DataParallel but this is not supported. This list will appear empty for the models replicated on each GPU except the original one.
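If the number of inputs should stay configurable, the separate registration can be done in a loop. This is a sketch, not the HTCN implementation: it assumes the random matrices are meant to stay fixed and registers them as buffers named random_matrix0, random_matrix1, and so on (hypothetical names), which nn.DataParallel copies to every replica. If they should be trained instead, self.register_parameter(name, nn.Parameter(...)) can be used in the same loop. forward takes the inputs as separate arguments so that DataParallel splits each of them along the batch dimension.

```python
import math

import torch
import torch.nn as nn


class RandomLayer(nn.Module):
    def __init__(self, input_dim_list, output_dim=1024):
        super().__init__()
        self.input_num = len(input_dim_list)
        self.output_dim = output_dim
        for i, dim in enumerate(input_dim_list):
            # One registered buffer per input, replicated by DataParallel.
            self.register_buffer(f"random_matrix{i}", torch.rand(dim, output_dim))

    def forward(self, *inputs):
        assert len(inputs) == self.input_num
        return_list = [torch.mm(x, getattr(self, f"random_matrix{i}"))
                       for i, x in enumerate(inputs)]
        return_tensor = return_list[0] / math.pow(float(self.output_dim), 1.0 / len(return_list))
        for single in return_list[1:]:
            return_tensor = torch.mul(return_tensor, single)
        return return_tensor


# Without a GPU, nn.DataParallel simply calls the wrapped module.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(RandomLayer([16, 8], 8)).to(device)
x1 = torch.randn(4, 16, device=device)
x2 = torch.randn(4, 8, device=device)
out = model(x1, x2)
print(out.shape)  # torch.Size([4, 8])
```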