Due to some design choices, I need to keep the PyTorch layers in a plain Python list (along with non-PyTorch callables such as NumPy functions). Doing this makes the network untrainable, because the parameters are not picked up when the layers sit inside a list.
Here is a stripped-down example:
from torch import nn
import numpy as np

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # A plain list is invisible to nn.Module.__setattr__, so the
        # Linear layer's parameters never get registered. (Note: avoid
        # the name `modules`, which would shadow nn.Module.modules().)
        self.layers = [nn.Linear(16, 8), np.multiply]

print(list(Net().parameters()))  # prints an empty list - expected the weights of nn.Linear
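For contrast, assigning the layer directly as an attribute does work, because nn.Module.__setattr__ intercepts assignments of Module instances and registers them. A minimal check (the class name NetDirect is just for illustration):

from torch import nn

class NetDirect(nn.Module):
    def __init__(self):
        super(NetDirect, self).__init__()
        self.linear = nn.Linear(16, 8)  # registered by nn.Module.__setattr__

print(list(NetDirect().parameters()))  # prints the Linear's weight and bias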
One answer is to use nn.ModuleList, but it does not allow non-PyTorch callables as elements. Again, due to the same design choices, I need both the PyTorch and non-PyTorch pieces in one iterable.
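To make the limitation concrete, a small sketch of the ModuleList approach (the class name NetML is mine): registration works, but adding a plain function fails.

from torch import nn
import numpy as np

class NetML(nn.Module):
    def __init__(self):
        super(NetML, self).__init__()
        # nn.ModuleList registers every element, so parameters() finds them
        self.layers = nn.ModuleList([nn.Linear(16, 8)])

print(list(NetML().parameters()))  # non-empty: the Linear's weight and bias

# ...but a ModuleList only accepts nn.Module instances:
try:
    nn.ModuleList([nn.Linear(16, 8), np.multiply])
except TypeError as e:
    print(e)  # np.multiply is not a Module subclass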
Although @InnovArul’s answer is perfectly OK, a solution I liked better was this.
I solved it by adding a single extra line that assigns the list element to an instance attribute. nn.Module.__setattr__ intercepts that assignment and registers the layer, so its parameters become visible to parameters():
import numpy as np
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layers = [nn.Linear(3, 1), np.multiply]
        self.one = self.layers[0]  # just this one extra line: the assignment registers the layer

model = Net()
print(list(model.parameters()))    # now contains the Linear's weight and bias
print(model.layers[0].weight)      # view original weights

optim = torch.optim.SGD(model.parameters(), lr=0.1)
x = model.layers[0](torch.ones(1, 3) * 5) - 1.0  # still calling the layer through the list
print(x)
x.backward()                       # fine: x has a single element
print(model.layers[0].weight.grad) # grad WILL be calculated!
optim.step()
print(model.layers[0].weight)      # view updated weights
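If the list held several layers, writing one extra attribute per element would not scale. A sketch of the same idea generalised (the class name NetMany and the attribute name registered are mine, not from the thread): keep the mixed list untouched, and additionally collect every nn.Module element into an nn.ModuleList so they all get registered at once.

import numpy as np
from torch import nn

class NetMany(nn.Module):
    def __init__(self):
        super(NetMany, self).__init__()
        # Mixed list: PyTorch layers interleaved with plain callables
        self.steps = [nn.Linear(3, 4), np.multiply, nn.Linear(4, 1)]
        # Register only the nn.Module elements; the mixed list itself
        # stays as-is for iteration.
        self.registered = nn.ModuleList(
            s for s in self.steps if isinstance(s, nn.Module)
        )

print(len(list(NetMany().parameters())))  # 4: two weights and two biases

The mixed list is still the one you iterate over; the ModuleList exists only so that parameters(), .to(), and state_dict() see the layers.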