Weights not passed to the optimiser are trained when nn.ModuleList is used

When I run a regular model and pass only one layer to the optimiser, it behaves as expected and only that layer is trained. But when I train a model that contains an `nn.ModuleList`, even parameters not passed to the optimiser are trained:


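    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    class testModel(nn.Module):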
      def __init__(self):
        super(testModel, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(3,3)]*2)
      def forward(self, x):
        x = self.layers[0](x)
        x = F.relu(x)
        x = self.layers[1](x)  
        return x
    
    t = testModel()
    x = torch.ones((1,3)) * -1
    target = torch.ones((1,3)) * 200
    
    w1_b4_train= t.layers[0].weight.data.numpy().copy()
    w2_b4_train= t.layers[1].weight.data.numpy().copy()
    
    opt = optim.Adam(t.layers[1].parameters(), lr = 0.01)
    loss = nn.MSELoss()
    
    for i in range(10):
      res = t(x)
      opt.zero_grad()
      l = loss(res, target)
      l.backward()
      opt.step()
    
    w1_after_train = t.layers[0].weight.data.numpy().copy()
    w2_after_train = t.layers[1].weight.data.numpy().copy()
    
    print(np.array_equal(w1_b4_train,w1_after_train))
    print(np.array_equal(w2_b4_train,w2_after_train))

This code prints:

    False
    False

Why is that? Is there a way to fix this?

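You are creating references in:

    [nn.Linear(3, 3)] * 2

List multiplication repeats the reference to one object instead of building new ones, so you are in fact initializing a single linear layer and placing it in the list twice. `layers[0]` and `layers[1]` therefore share the same weight and bias tensors: passing `t.layers[1].parameters()` to the optimiser also hands it the parameters of `t.layers[0]`, and every step updates both.
Use `[nn.Linear(3, 3) for _ in range(2)]` to create new objects.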
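
One way to confirm this is to check object identity directly. This is a minimal sketch; the names `shared` and `separate` are only for illustration. It relies on `Module.parameters()` skipping parameters it has already yielded, which is why the shared list reports only one weight and one bias:

```python
import torch.nn as nn

# List multiplication repeats the reference: both slots hold one module
shared = nn.ModuleList([nn.Linear(3, 3)] * 2)
print(shared[0] is shared[1])                # True
print(shared[0].weight is shared[1].weight)  # True
# Duplicates are skipped, so only one weight and one bias are listed
print(len(list(shared.parameters())))        # 2

# The comprehension calls nn.Linear(3, 3) twice: two independent modules
separate = nn.ModuleList([nn.Linear(3, 3) for _ in range(2)])
print(separate[0] is separate[1])            # False
print(len(list(separate.parameters())))      # 4
```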
Here is a small example:

    import torch
    import torch.nn as nn

    # List multiplication repeats a reference to one module
    layers = nn.ModuleList([nn.Linear(3, 3)] * 2)
    # Both entries are the same module, so the printed weights are identical
    print(layers[0].weight)
    print(layers[1].weight)

    # An in-place change made through layers[0] is visible through layers[1]
    with torch.no_grad():
        layers[0].weight.fill_(0.)
    print(layers[0].weight)
    print(layers[1].weight)

    # Create new modules
    layers = nn.ModuleList([nn.Linear(3, 3) for _ in range(2)])
    # Different, independently initialized weights
    print(layers[0].weight)
    print(layers[1].weight)

    # Only layers[0] is zeroed; layers[1] keeps its values
    with torch.no_grad():
        layers[0].weight.fill_(0.)
    print(layers[0].weight)
    print(layers[1].weight)
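
Applied to the code from the question, the fix is just the list comprehension. The sketch below is a lightly tidied version of that training loop (the class is renamed `TestModel`, and the `unchanged` helper is my own, not part of PyTorch). It compares all parameters, weights and biases, because with the negative input the ReLU may zero every activation, in which case `layers[1].weight` gets no gradient and only `layers[1].bias` moves:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Two independent Linear layers, each with its own parameters
        self.layers = nn.ModuleList([nn.Linear(3, 3) for _ in range(2)])

    def forward(self, x):
        x = F.relu(self.layers[0](x))
        return self.layers[1](x)


t = TestModel()
x = torch.ones((1, 3)) * -1
target = torch.ones((1, 3)) * 200

# Snapshot every parameter of each layer before training
before = [[p.detach().clone() for p in layer.parameters()] for layer in t.layers]

# Only the second layer's parameters are handed to the optimiser
opt = optim.Adam(t.layers[1].parameters(), lr=0.01)
loss_fn = nn.MSELoss()

for _ in range(10):
    opt.zero_grad()
    loss = loss_fn(t(x), target)
    loss.backward()
    opt.step()


def unchanged(layer, snapshot):
    # Hypothetical helper: True if no parameter of `layer` differs from the snapshot
    return all(torch.equal(p, s) for p, s in zip(layer.parameters(), snapshot))


layer0_unchanged = unchanged(t.layers[0], before[0])
layer1_unchanged = unchanged(t.layers[1], before[1])
print(layer0_unchanged)  # True: layers[0] was not passed to the optimiser
print(layer1_unchanged)  # False: at least layers[1].bias was updated
```

Note that `layers[0]` still receives gradients in `backward()`; it stays fixed only because the optimiser never steps it.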