Module containg lists

What is the expected behavior for lists contained in an nn.Module subclass ?

Refer to example code below:

  1. model = model.cuda() fails with type mismatch error
  2. If the elements of list are converted to cuda within init(), forward() seems to work but backward() fails

from __future__ import print_function
import torch
import torch.nn as nn
from torch.autograd import Variable 

class MyLinear(nn.Module):

    def __init__(self):
        super(MyLinear, self).__init__()
        self.lin = nn.Linear(2, 4)
        # self.lin_list = [nn.Linear(4,2), nn.Linear(4,2)] # CASE 1: error for mixing cuda and non-cuda types 
        self.lin_list = [nn.Linear(4,2).cuda(), nn.Linear(4,2).cuda()]  # CASE 2:  forward() works, but backward() doesn't 

    def forward(self, x):
        out = self.lin(x)
        out1 = self.lin_list[0](out)
        out2 = self.lin_list[1](out)
        out = torch.cat((out1, out2))
        return out

if __name__ == '__main__':

    x = Variable(torch.FloatTensor(2).random_().cuda())
    y = Variable(torch.FloatTensor(4).random_().cuda())
    print(x)
    print(y)
    model = MyLinear()
    model = model.cuda()

    out = model(x)
    print(out)

    print(list(model.parameters()))

    model.zero_grad()
    out.backward(y)

    print(list(model.parameters()))

Hi,

They should be inside a ModuleList module to be detected properly, see doc here.

3 Likes