Mismatch between input type and weight type when using nn.ModuleList

Hello guys,
I have a problem with a list of modules. I have written a module called BranchRoutingModule, and I would like to create a (nested) list of instances of this module.
I have the following code:

def _branch_routings(self):
  """Create and register the branch-routing modules of a binary tree.

  Level ``i`` (for ``i`` in ``0 .. tree_height - 2``) holds ``2 ** i``
  ``BranchRoutingModule`` instances built with ``in_channels=self.in_channels``.
  Each one is registered on ``self`` as ``branch_routing_module<k>``, where
  ``k`` counts up across all levels. Registration is what puts the module's
  weights into ``self.parameters()`` / ``state_dict()`` and lets
  ``self.cuda()`` / ``self.to(...)`` move them.

  Returns:
    A plain list of lists of the registered modules, e.g. ``[[m0], [m1, m2]]``
    for ``tree_height == 3``. The outer lists themselves are not registered
    containers; the modules are tracked only through the registration above.
  """
  structure = []
  cur = 0
  for i in range(self.tree_height - 1):
    level = []
    for _ in range(2 ** i):
      module = BranchRoutingModule(in_channels=self.in_channels)
      # Register with nn.Module explicitly: a module kept only inside a plain
      # Python list is invisible to .cuda()/.to(), so its weights would stay
      # on the device they were created on.
      self.add_module('branch_routing_module' + str(cur), module)
      level.append(module)
      cur += 1
    structure.append(level)
  return structure

I first tried using nn.ModuleList (the commented-out line at the top), but I get the following error: “Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same”

However, if I use setattr and getattr, I get no errors and my model works fine.

Why is that? I don’t understand why setattr and getattr fix the problem.
I am using CUDA.

Thank you and regards,
Antoine

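It seems that in your initial approach you were storing the nn.ModuleLists inside a plain Python list, which won’t work. nn.Module only registers submodules that are assigned as its attributes, directly or wrapped in nn.ModuleList/nn.ModuleDict. Modules hidden in a plain list are therefore never moved by model.cuda(), so their weights stay on the CPU (torch.FloatTensor) while your input is on the GPU (torch.cuda.FloatTensor). Your setattr version works because each BranchRoutingModule is assigned as an attribute and is therefore registered.
Make sure to register nn.Modules directly or wrap them in nn.ModuleLists (not plain Python lists).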

1 Like

Hi @ptrblck , thanks for your reply!
I have tried modifying my initial approach to:

# Nested container: level i holds 2 ** i BranchRoutingModules.
# NOTE(review): an nn.ModuleList only gets registered (so that parameters(),
# .cuda() and .to() reach the modules inside it) once it is assigned as an
# attribute of an nn.Module, e.g. `self.branch_routings = structure`. If it is
# only kept in a local variable or a plain Python list/dict, its weights are
# not moved with the model — confirm how the caller stores the return value.
structure = nn.ModuleList([nn.ModuleList([BranchRoutingModule(in_channels=self.in_channels) for j in range(int(pow(2, i)))]) for i in range(self.tree_height - 1)]) 

I still get the same error: “Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same”.
Running out of ideas here…

For completeness, here is my BranchRoutingModule:

class BranchRoutingModule(nn.Module):
  """Compute one sigmoid routing score per sample.

  ``forward`` applies, in order: a 1x1 convolution (channels unchanged),
  adaptive average pooling to 1x1, a signed square root, L2 normalisation
  over dim 1, flattening, a linear layer to a single unit, and a sigmoid.

  Args:
    in_channels: number of input channels; used by both the 1x1 conv and the
      linear layer.
    epsilon: added under the square root in ``signedsqrt`` (presumably to
      keep the sqrt and its gradient finite at 0).
  """

  def __init__(self, in_channels=512, epsilon=1e-12):
    super().__init__()

    self.epsilon = epsilon
    self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1)
    self.fc = nn.Linear(in_channels, 1)
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # size 1x1
    self.sigmoid = nn.Sigmoid()

  # l2norm/signedsqrt are methods rather than lambdas stored on the instance:
  # lambda attributes cannot be pickled, which breaks torch.save(model).
  def l2norm(self, x):
    """L2-normalise ``x`` along dim 1."""
    return torch.nn.functional.normalize(x, dim=1)

  def signedsqrt(self, x):
    """Return ``sign(x) * sqrt(|x| + epsilon)`` elementwise."""
    return torch.sign(x) * torch.sqrt(torch.abs(x) + self.epsilon)

  def forward(self, x):
    """Return a ``(batch, 1)`` tensor of sigmoid scores.

    ``x`` is expected to be ``(batch, in_channels, H, W)`` and on the same
    device as this module's parameters; nn.Conv2d raises a type-mismatch
    error otherwise.
    """
    feature = self.conv1(x)
    feature = self.avgpool(feature)        # (batch, C, 1, 1)
    feature = self.signedsqrt(feature)
    feature = self.l2norm(feature)
    feature = torch.flatten(feature, 1)    # (batch, C)
    out = self.fc(feature)
    return self.sigmoid(out)

Thanks!

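Your code works fine as long as the nested nn.ModuleList is assigned as an attribute of the parent module (self.structure below), so that it gets registered: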

class BranchRoutingModule(nn.Module):
  """Compute one sigmoid routing score per sample.

  ``forward`` applies, in order: a 1x1 convolution (channels unchanged),
  adaptive average pooling to 1x1, a signed square root, L2 normalisation
  over dim 1, flattening, a linear layer to a single unit, and a sigmoid.

  Args:
    in_channels: number of input channels; used by both the 1x1 conv and the
      linear layer.
    epsilon: added under the square root in ``signedsqrt`` (presumably to
      keep the sqrt and its gradient finite at 0).
  """

  def __init__(self, in_channels=512, epsilon=1e-12):
    super().__init__()

    self.epsilon = epsilon
    self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1)
    self.fc = nn.Linear(in_channels, 1)
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # size 1x1
    self.sigmoid = nn.Sigmoid()

  # l2norm/signedsqrt are methods rather than lambdas stored on the instance:
  # lambda attributes cannot be pickled, which breaks torch.save(model).
  def l2norm(self, x):
    """L2-normalise ``x`` along dim 1."""
    return torch.nn.functional.normalize(x, dim=1)

  def signedsqrt(self, x):
    """Return ``sign(x) * sqrt(|x| + epsilon)`` elementwise."""
    return torch.sign(x) * torch.sqrt(torch.abs(x) + self.epsilon)

  def forward(self, x):
    """Return a ``(batch, 1)`` tensor of sigmoid scores.

    ``x`` is expected to be ``(batch, in_channels, H, W)`` and on the same
    device as this module's parameters.
    """
    feature = self.conv1(x)
    feature = self.avgpool(feature)        # (batch, C, 1, 1)
    feature = self.signedsqrt(feature)
    feature = self.l2norm(feature)
    feature = torch.flatten(feature, 1)    # (batch, C)
    out = self.fc(feature)
    return self.sigmoid(out)

class MyModel(nn.Module):
    """Run a grid of BranchRoutingModules on the same input and concatenate.

    The nested ``nn.ModuleList`` is assigned to ``self.structure``. That
    assignment is what registers every inner module, so ``.cuda()``/``.to()``
    and ``named_parameters()`` reach all of them. Keeping the modules only in
    a plain Python list would not.

    Args:
        num_levels: number of inner ``nn.ModuleList``s (default 5).
        modules_per_level: modules per inner list (default 10).
        in_channels: passed to every ``BranchRoutingModule`` (default 512,
            the same as that class's own default).

    Raises:
        ValueError: if ``num_levels`` or ``modules_per_level`` is not positive.
    """

    def __init__(self, num_levels=5, modules_per_level=10, in_channels=512):
        super().__init__()
        if num_levels < 1 or modules_per_level < 1:
            # forward() concatenates the outputs; torch.cat rejects an empty list.
            raise ValueError('num_levels and modules_per_level must be positive')
        self.structure = nn.ModuleList([
            nn.ModuleList([
                BranchRoutingModule(in_channels=in_channels)
                for _ in range(modules_per_level)
            ])
            for _ in range(num_levels)
        ])

    def forward(self, x):
        """Apply every module to ``x`` and concatenate the results along dim 1.

        With ``BranchRoutingModule`` returning ``(batch, 1)``, the result is
        ``(batch, num_levels * modules_per_level)``.
        """
        outs = [m(x) for modules in self.structure for m in modules]
        return torch.cat(outs, dim=1)
        
    
# Use the GPU when one is available so the demo also runs on CPU-only machines.
# What matters is that the parameters and the input end up on the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel().to(device)
# make sure parameters are on the device
for name, param in model.named_parameters():
    print('{}, {}'.format(name, param.device))
x = torch.randn(2, 512, 24, 24).to(device)
out = model(x)
print(out.shape)
# > torch.Size([2, 50])

All parameters are pushed to the device and no errors are raised.

1 Like