What is the expected behavior for lists contained in an nn.Module subclass ?
Refer to example code below:
- model = model.cuda() fails with type mismatch error
- If the elements of list are converted to cuda within init(), forward() seems to work but backward() fails
from __future__ import print_function
import torch
import torch.nn as nn
from torch.autograd import Variable
class MyLinear(nn.Module):
def __init__(self):
super(MyLinear, self).__init__()
self.lin = nn.Linear(2, 4)
# self.lin_list = [nn.Linear(4,2), nn.Linear(4,2)] # CASE 1: error for mixing cuda and non-cuda types
self.lin_list = [nn.Linear(4,2).cuda(), nn.Linear(4,2).cuda()] # CASE 2: forward() works, but backward() doesn't
def forward(self, x):
out = self.lin(x)
out1 = self.lin_list[0](out)
out2 = self.lin_list[1](out)
out = torch.cat((out1, out2))
return out
if __name__ == '__main__':
x = Variable(torch.FloatTensor(2).random_().cuda())
y = Variable(torch.FloatTensor(4).random_().cuda())
print(x)
print(y)
model = MyLinear()
model = model.cuda()
out = model(x)
print(out)
print(list(model.parameters()))
model.zero_grad()
out.backward(y)
print(list(model.parameters()))