Weights not updated in a custom module

I am facing a weird issue when trying to optimize the parameters of a custom module. In particular, certain parameters are not updated, even though their gradients are computed.

My CustomModule is defined as:

  • A list of Batch Normalization (BN) layers, of which only one is chosen for the forward/backward pass. That is, the model will learn different gammas/betas depending on the selected BN layer.

  • A fully-connected layer.

Of course, this structure, defined as such, does not make much sense; I just want to focus on the issue I ran into. Below is the structure of the CustomModule.

import torch
import torch.nn as nn

class CustomModule(nn.Module):
    def __init__(self, num_batch_norm):
        super(CustomModule, self).__init__()

        self.batch_norm_list = [nn.BatchNorm1d(10) for _ in range(num_batch_norm)]
        self.fc = nn.Linear(10, 2)

    def forward(self, x, num_bn):
        x = self.batch_norm_list[num_bn](x)
        x = self.fc(x)

        return x

I create a model (an instance of CustomModule) with three BN layers and an SGD optimizer. Then I define test variables x and y and perform a single pass through the model (including the backward pass).

torch.manual_seed(0)

# Create a 'CustomModule' with three BN layers.
model = CustomModule(3)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# Define test variables: the input (x) and the target (y).
x = torch.rand(5, 10)
y = torch.rand(5, 2)

# Pass 'x' through the 2nd (index=1) BN layer.
out = model(x, 1)

loss = torch.linalg.norm(out - y)

model.zero_grad()
loss.backward()
optimizer.step()

I check that the gradients for the 2nd BN layer (chosen in the forward pass) are computed.

# Check the weights' gradients in the 2nd BN layer.
model.batch_norm_list[1].weight.grad
# tensor([ 0.0810,  0.0052,  0.2871,  0.3091,  0.2437,  0.0916,  0.1410,  0.4406,
#         -0.0288,  0.0740])

However, when I check the weights (gammas) after the optimization step, I notice that they are not updated (in batch normalization, the gammas (weights) are initialized to one).

# Check the weights in the 2nd BN layer (after optimization)
model.batch_norm_list[1].weight
# Parameter containing:
# tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)

Given the initial weights (a tensor of ones) and the gradients, I expected the updated weights to look like the following (using the SGD update rule).

# Expected values
weights = model.batch_norm_list[1].weight
gradients = model.batch_norm_list[1].weight.grad
lr = 1e-2

expected_weights = weights - lr * gradients
# tensor([0.9992, 0.9999, 0.9971, 0.9969, 0.9976, 0.9991, 0.9986, 0.9956, 1.0003,
#         0.9993], grad_fn=<SubBackward0>)

So, I want to know why this issue occurs and how to solve it.

model.zero_grad()

I think this line should be changed to optimizer.zero_grad()
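
For reference, one common ordering of a training step (a minimal sketch, using the same names as in the question), where the optimizer's gradients are cleared before the backward pass:

optimizer.zero_grad()              # clear gradients from the previous step
out = model(x, 1)                  # forward pass through the 2nd BN layer
loss = torch.linalg.norm(out - y)
loss.backward()                    # compute gradients
optimizer.step()                   # apply the SGD update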


You are storing the weights in a conventional Python list, which is not tracked by the optimizer.
You should use a PyTorch list.
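
For illustration, a quick check with the CustomModule defined above (a minimal sketch) shows that the parameters of the BN layers stored in a plain Python list never reach model.parameters(), and therefore never reach the optimizer:

model = CustomModule(3)

# Only the fully-connected layer is registered as a submodule; the BN
# layers sit in a plain Python list, so their gammas/betas are missing.
print([name for name, _ in model.named_parameters()])
# ['fc.weight', 'fc.bias']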


@Sangyeon_Kim: Yes, you are right, thanks for the remark.

@JuanFMontesinos: Do you mean that I have to convert self.batch_norm_list (which holds the BatchNorm layers) into a PyTorch tensor? But how can I do this? Using torch.tensor(self.batch_norm_list) doesn’t work because the element type (in this case, BatchNorm1d) cannot be inferred; below is the error message.

<ipython-input-370-e40c7345c9d8> in __init__(self, num_batch_norm)
      4 
      5         self.batch_norm_list = [nn.BatchNorm1d(10) for _ in range(num_batch_norm)]
----> 6         self.batch_norm_list = torch.tensor(self.batch_norm_list)
      7 
      8         self.fc = nn.Linear(10, 2)

RuntimeError: Could not infer dtype of BatchNorm1d

Sorry, I should have been more explicit.
You have to use nn.ModuleList (ModuleList — PyTorch 1.8.1 documentation), which is a list-like module that is properly tracked by the engine. Keep in mind that any attribute of an nn.Module other than a tensor or another module is not tracked by the engine.
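
For example, a minimal sketch of the module rewritten with nn.ModuleList (same layer sizes as in the question):

import torch
import torch.nn as nn

class CustomModule(nn.Module):
    def __init__(self, num_batch_norm):
        super(CustomModule, self).__init__()

        # nn.ModuleList registers each BN layer as a submodule, so their
        # parameters appear in model.parameters() and get optimized.
        self.batch_norm_list = nn.ModuleList(
            [nn.BatchNorm1d(10) for _ in range(num_batch_norm)]
        )
        self.fc = nn.Linear(10, 2)

    def forward(self, x, num_bn):
        x = self.batch_norm_list[num_bn](x)
        x = self.fc(x)

        return x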


It works! I didn’t know about ModuleList.
Thank you very much.
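
For completeness, a quick sanity check with the ModuleList-based version of the module (same test setup as in the question; the exact values depend on the seed):

torch.manual_seed(0)

model = CustomModule(3)   # ModuleList-based version
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.rand(5, 10)
y = torch.rand(5, 2)

out = model(x, 1)
loss = torch.linalg.norm(out - y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# The BN parameters are now registered, so the weights of the 2nd BN
# layer move away from their initial value of one after the update.
print(model.batch_norm_list[1].weight)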