Register layers within list as parameters

Due to some design choices, I need to keep the PyTorch layers in a list (along with other non-PyTorch callables). Doing this makes the network untrainable, because the parameters are not picked up when the layers are inside a plain list.

This is a dumbed-down example:

from torch import nn
import numpy as np


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Plain list; named `layers` so it does not shadow nn.Module.modules()
        self.layers = [nn.Linear(16, 8), np.multiply]


print(list(Net().parameters()))  # prints an empty list - expected to print weights of nn.Linear
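To see why the list is ignored: nn.Module registers a value as a submodule only when an nn.Module (or Parameter) is assigned directly to an attribute; a Python list is stored as an ordinary attribute and its contents are never inspected. The sketch below, with hypothetical class names, compares the two cases.

```python
import numpy as np
from torch import nn


class PlainListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # A list is not an nn.Module, so it is stored as a plain attribute
        # and the Linear inside it is never registered.
        self.layers = [nn.Linear(16, 8), np.multiply]


class AttrNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning an nn.Module directly registers it as a child module.
        self.linear = nn.Linear(16, 8)


plain, attr = PlainListNet(), AttrNet()
print(list(dict(plain.named_children())))  # []
print(list(dict(attr.named_children())))   # ['linear']
print(len(list(plain.parameters())))       # 0
print(len(list(attr.parameters())))        # 2 (weight and bias)
```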

How can I achieve what I expect?

One option is nn.ModuleList, but it only accepts nn.Module instances, so non-PyTorch callables cannot be part of it. Again, due to some design choices, I need both the PyTorch and the non-PyTorch entries to be part of the same iterable.
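One way to keep the mixed iterable is to leave it as a plain list and additionally put its nn.Module entries into a separate nn.ModuleList. This is a sketch, not something proposed in the thread: the names `MixedNet`, `steps` and `registered` are hypothetical, and the lambda stands in for a non-PyTorch callable (it uses tensor arithmetic so gradients still flow). Because both containers hold the same layer objects, the parameters are registered exactly once.

```python
import torch
from torch import nn


class MixedNet(nn.Module):
    def __init__(self):
        super().__init__()
        # The mixed iterable the design requires (hypothetical name `steps`).
        self.steps = [nn.Linear(16, 8), lambda t: t * 2.0, nn.Linear(8, 4)]
        # Parallel ModuleList with only the nn.Module entries. It holds the
        # same objects as `self.steps`, so their parameters get registered.
        self.registered = nn.ModuleList(
            [s for s in self.steps if isinstance(s, nn.Module)]
        )

    def forward(self, x):
        # Modules and plain callables are both called the same way.
        for step in self.steps:
            x = step(x)
        return x


net = MixedNet()
print(len(list(net.parameters())))  # 4: weight and bias of each Linear
out = net(torch.ones(1, 16))
out.sum().backward()
print(net.steps[0].weight.grad is not None)  # True
```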

I am not sure what your goal is here.
Technically, you could write a dummy nn.Module that wraps each non-PyTorch callable, as below:

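import torch
import torch.nn as nn
import numpy as np

class NonTorchContainer(nn.Module):
    """Wraps an arbitrary callable so it can live inside an nn.ModuleList."""
    def __init__(self, fn):
        super(NonTorchContainer, self).__init__()  # required for any nn.Module subclass
        self.fn = fn

    def forward(self, *args, **kwargs):
        # Delegate to the wrapped callable
        return self.fn(*args, **kwargs)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Named `layers`, not `modules`: a submodule stored as `modules` is
        # hidden behind the built-in nn.Module.modules() method.
        self.layers = nn.ModuleList([nn.Linear(16, 8),
                                     NonTorchContainer(np.multiply)])

    def forward(self,): pass  # not needed for this illustration

if __name__ == '__main__':
    x = Net()
    print(list(x.parameters()))  # weight and bias of the Linear layer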

Be aware that non-PyTorch functions such as np.multiply break the computation graph, so gradients will not be backpropagated through them.
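A small check of this point, assuming nothing beyond a float tensor `w`: torch.mul is recorded by autograd, while the NumPy route requires `.detach()` first (calling `.numpy()` on a tensor that requires grad raises an error), and that detach cuts the link back to `w`.

```python
import numpy as np
import torch

w = torch.ones(3, requires_grad=True)

# Staying in torch: the multiplication is recorded, so gradients reach w.
y = torch.mul(w, 2.0).sum()
y.backward()
print(w.grad)  # tensor([2., 2., 2.])

# Going through NumPy: the tensor must be detached before conversion,
# so the result has no path back to w.
z = np.multiply(w.detach().numpy(), 2.0)
z_t = torch.from_numpy(z)
print(z_t.requires_grad)  # False
```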


Although @InnovArul’s answer is perfectly OK, the following solution suited me better.

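I solved it by adding a single line that creates an instance attribute referring to the list element. Because that attribute holds an nn.Module, PyTorch registers it as a submodule, and since it is the same object as the list entry, the parameters are shared between the two.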

import numpy as np
import torch
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Named `layers` so it does not shadow nn.Module.modules()
        self.layers = [nn.Linear(3, 1), np.multiply]
        self.one = self.layers[0]  # Just have this line extra: registers the same Linear


model = Net()
print(list(model.parameters()))  # now lists the Linear weight and bias

print(model.layers[0].weight)  # View original weights
optim = torch.optim.SGD(model.parameters(), lr=0.1)
x = model.layers[0](torch.ones(1, 3) * 5) - 1.0  # Still calling through self.layers
print(x)
x.backward()
print(model.layers[0].weight.grad)  # Grad WILL be calculated!
optim.step()
print(model.layers[0].weight)  # View updated weights
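The extra `self.one = ...` line can be generalized to any number of layers by calling `add_module` for every nn.Module in the list. This is a sketch under assumptions: the list name `layers` and the generated names `layer0`, `layer2`, … are hypothetical. The registered names are what appear in `named_parameters()` and `state_dict()` keys.

```python
import numpy as np
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Mixed list required by the design (hypothetical name `layers`).
        self.layers = [nn.Linear(3, 1), np.multiply, nn.Linear(1, 1)]
        # Register each nn.Module under a stable name; the list and the
        # registered attribute refer to the same object.
        for i, layer in enumerate(self.layers):
            if isinstance(layer, nn.Module):
                self.add_module(f"layer{i}", layer)


model = Net()
print([name for name, _ in model.named_parameters()])
# ['layer0.weight', 'layer0.bias', 'layer2.weight', 'layer2.bias']
```

If a list entry is later replaced with a new module, it has to be registered again, since the registered attribute still points to the old object.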