Use of BatchNorm1d seems to prevent some model parameters from being updated

Looking for guidance around the use of BatchNorm1d and how its use affects parameter updating.

Are there any guidelines on how to use BatchNorm1d?

The reason I’m asking is that it appears that when I use BatchNorm1d, some model parameters are not updated after the backward pass and optimizer step.

Here is a minimal example that illustrates the issue:

import torch

INPUT_SIZE = 2
OUTPUT_SIZE = 2
BATCH_SIZE = 4
NUM_STEPS = 2


class SubModule2WithBN(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.fc = torch.nn.Linear(size, size)
        self.bn = torch.nn.BatchNorm1d(size)

    def forward(self, inputs):
        hidden = self.fc(inputs)
        return self.bn(hidden)


class SubModule2WithoutBN(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.fc = torch.nn.Linear(size, size)

    def forward(self, inputs):
        return self.fc(inputs)


class MyModel(torch.nn.Module):

    def __init__(self, input_size, output_size):
        super().__init__()

        self.input_size = input_size
        self.output_size = output_size

        self.fc1 = torch.nn.Linear(input_size, output_size)
        self.fc2 = torch.nn.Linear(output_size, output_size)

        self.submodule2 = torch.nn.ModuleList()

        for _ in range(NUM_STEPS):
            # Comment out one or the other of the following statements to run this minimal example
            self.submodule2.append(SubModule2WithBN(output_size))   # this will result in some parameters not updated
            # self.submodule2.append(SubModule2WithoutBN(output_size))  # this will result in all parameters updated

    def forward(self, inputs):
        hidden = self.fc1(inputs)
        for step in range(NUM_STEPS):
            hidden = self.submodule2[step](hidden)

        return self.fc2(hidden)


# generate synthetic tensor to feed into model
torch.random.manual_seed(1919)
input_tensor = torch.randn([BATCH_SIZE, INPUT_SIZE], dtype=torch.float32)

# create model instance
model = MyModel(INPUT_SIZE, OUTPUT_SIZE)
model.train(True)

# run synthetic input tensor through the model
model_output = model(input_tensor)

#
# Purpose of the following code is to test if all parameters were updated
# not to achieve an optimal solution.
#
# setup
loss_function = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()

# generate pseudo target tensor used for calculating loss
target_tensor = torch.randn(model_output.shape, dtype=model_output.dtype)

# capture model parameters before doing parameter update pass
before = [(x[0], x[1].clone()) for x in model.named_parameters()]

# do update of model parameters
loss = loss_function(model_output, target_tensor)
loss.backward()
optimizer.step()

# capture model parameters after pass for updating parameters
after = [(x[0], x[1].clone()) for x in model.named_parameters()]

# confirm that all parameters were updated
parameters_updated = []
parameter_update_results = []
for b, a in zip(before, after):
    parameter_updated = (a[1] != b[1]).any()
    parameters_updated.append(parameter_updated)
    parameter_update_results.append(
        f'\nParameter {a[0]} updated: {parameter_updated}{" <<<<<<<" if not parameter_updated else ""}\n'
        f'before model forward() pass (requires grad:{b[1].requires_grad}): \n\t{b[1]}\n'
        f'after model forward() pass (requires grad:{a[1].requires_grad}): \n\t{a[1]}\n'
        f'|before - after|: \n\t{torch.abs(b[1] - a[1])}\n'
    )
print(
    f'\nAll parameters updated: {all(parameters_updated)} {" <<<<<<<" if not all(parameters_updated) else ""}'
    f'{"".join(parameter_update_results)}'
)

In the above example code, if I use the class SubModule2WithBN, which contains BatchNorm1d, I see output like this:

All parameters updated: False  <<<<<<<
Parameter fc1.weight updated: True
before model forward() pass (requires grad:True): 
	tensor([[-0.5268,  0.5782],
        [-0.4338, -0.2834]], grad_fn=<CloneBackward0>)
after model forward() pass (requires grad:True): 
	tensor([[-0.5336,  0.5678],
        [-0.4505, -0.2665]], grad_fn=<CloneBackward0>)
|before - after|: 
	tensor([[0.0068, 0.0104],
        [0.0167, 0.0169]], grad_fn=<AbsBackward0>)

Parameter fc1.bias updated: False <<<<<<<
before model forward() pass (requires grad:True): 
	tensor([0.5317, 0.6025], grad_fn=<CloneBackward0>)
after model forward() pass (requires grad:True): 
	tensor([0.5317, 0.6025], grad_fn=<CloneBackward0>)
|before - after|: 
	tensor([0., 0.], grad_fn=<AbsBackward0>)
	
<<<<<<<<<<<<< deleted lines that showed parameters were updated >>>>>>>>>>>>>>>>>>>>

Parameter submodule2.1.fc.bias updated: False <<<<<<<
before model forward() pass (requires grad:True): 
	tensor([-0.5061,  0.0665], grad_fn=<CloneBackward0>)
after model forward() pass (requires grad:True): 
	tensor([-0.5061,  0.0665], grad_fn=<CloneBackward0>)
|before - after|: 
	tensor([0., 0.], grad_fn=<AbsBackward0>)

Parameter submodule2.1.bn.weight updated: True
before model forward() pass (requires grad:True): 
	tensor([1., 1.], grad_fn=<CloneBackward0>)
after model forward() pass (requires grad:True): 
	tensor([0.9230, 0.9826], grad_fn=<CloneBackward0>)
|before - after|: 
	tensor([0.0770, 0.0174], grad_fn=<AbsBackward0>)

Parameter submodule2.1.bn.bias updated: True
before model forward() pass (requires grad:True): 
	tensor([0., 0.], grad_fn=<CloneBackward0>)
after model forward() pass (requires grad:True): 
	tensor([-0.0509, -0.0126], grad_fn=<CloneBackward0>)
|before - after|: 
	tensor([0.0509, 0.0126], grad_fn=<AbsBackward0>)



The output above shows two parameters that were not updated (fc1.bias and submodule2.1.fc.bias).

OTOH, if I use the class SubModule2WithoutBN, which does not contain BatchNorm1d, then all parameters are updated:

All parameters updated: True 
Parameter fc1.weight updated: True
before model forward() pass (requires grad:True): 
	tensor([[-0.5268,  0.5782],
        [-0.4338, -0.2834]], grad_fn=<CloneBackward0>)
after model forward() pass (requires grad:True): 
	tensor([[-0.5212,  0.5742],
        [-0.4376, -0.2759]], grad_fn=<CloneBackward0>)
|before - after|: 
	tensor([[0.0057, 0.0040],
        [0.0038, 0.0075]], grad_fn=<AbsBackward0>)

Parameter fc1.bias updated: True               # <= this is now updated
before model forward() pass (requires grad:True): 
	tensor([0.5317, 0.6025], grad_fn=<CloneBackward0>)
after model forward() pass (requires grad:True): 
	tensor([0.5322, 0.5955], grad_fn=<CloneBackward0>)
|before - after|: 
	tensor([0.0005, 0.0070], grad_fn=<AbsBackward0>)

<<<<<<<<<<<<< deleted lines that showed parameters were updated >>>>>>>>>>>>>>>>>>>>

Parameter submodule2.1.fc.weight updated: True
before model forward() pass (requires grad:True): 
	tensor([[-0.3351, -0.6894],
        [ 0.2473, -0.2471]], grad_fn=<CloneBackward0>)
after model forward() pass (requires grad:True): 
	tensor([[-0.3261, -0.6792],
        [ 0.2027, -0.2713]], grad_fn=<CloneBackward0>)
|before - after|: 
	tensor([[0.0090, 0.0102],
        [0.0445, 0.0242]], grad_fn=<AbsBackward0>)

Parameter submodule2.1.fc.bias updated: True           #<= this is now updated
before model forward() pass (requires grad:True): 
	tensor([-0.5061,  0.0665], grad_fn=<CloneBackward0>)
after model forward() pass (requires grad:True): 
	tensor([-0.4968,  0.0194], grad_fn=<CloneBackward0>)
|before - after|: 
	tensor([0.0093, 0.0471], grad_fn=<AbsBackward0>)



Some specific questions:

  • Am I using BatchNorm1d correctly?
  • Is comparing each parameter’s value before and after optimizer.step(), as in the script above, a reasonable way to test whether parameters are updated?

Any guidance will be appreciated.


Thanks for the great code and the interesting question!
Based on your output (and my runs) it seems that the bias of the preceding linear layer is not updated.
The check is correct about that, and if you inspect the gradient of the bias parameter, you will see that it is indeed really small:

model.fc1.bias.grad
> tensor([ 5.9605e-08, -1.7881e-07])

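and thus the bias value is not changed: with lr=0.1 the step is only on the order of 1e-8, which is below the float32 resolution at a value of about 0.5, so it is lost to rounding.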
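To illustrate why gradients this small leave the value unchanged, the following sketch applies one plain SGD step by hand in float32. The bias values and gradients are copied from the output above and lr=0.1 is taken from the script; the comparison against the float32 spacing (via `torch.nextafter`) is an illustrative addition:

```python
import torch

lr = 0.1
# Values copied from the printed fc1.bias and fc1.bias.grad above (float32 by default).
bias = torch.tensor([0.5317, 0.6025])
grad = torch.tensor([5.9605e-08, -1.7881e-07])

# One plain SGD step without momentum: p <- p - lr * grad
updated = bias - lr * grad

# Gap between each bias entry and the next representable float32 value above it.
spacing = torch.nextafter(bias, torch.full_like(bias, float("inf"))) - bias

print(lr * grad.abs())              # step sizes, roughly 6e-9 and 1.8e-8
print(spacing)                      # float32 spacing in [0.5, 1), roughly 6e-8
print(torch.equal(updated, bias))   # True: both steps are rounded away
```

Because each step is smaller than half the spacing, rounding to the nearest float32 returns the original value.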

Now, this might be surprising, but let’s think about the first operation of a batchnorm layer:

out = (x - mean) / stddev * weight + bias

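As you can see, the mean of the incoming batch is subtracted. If the previous layer has added a bias to the activation, that bias shifts every sample, and therefore the batch mean, by the same amount, so it is subtracted again right away: (x + b) - mean(x + b) = x - mean(x). The batch variance is unaffected by such a shift as well.
Wouldn’t this also mean that this parameter has (almost) zero influence on the loss calculation (up to the limits of numerical precision)?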
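A quick numerical check of this argument, as a sketch with arbitrary shapes and random data: in training mode, BatchNorm1d produces the same output whether or not a constant per-feature shift is added to its input, and the gradient with respect to that shift vanishes up to rounding error. The tensor `b` is a hypothetical stand-in for the preceding layer’s bias:

```python
import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm1d(2)
bn.train()  # normalize with batch statistics, as during training

x = torch.randn(8, 2)
b = torch.tensor([0.5, -1.5], requires_grad=True)  # stand-in for a preceding bias

out_shifted = bn(x + b)
out_plain = bn(x)
# The shift is removed by the mean subtraction, so both outputs agree.
print(torch.allclose(out_shifted, out_plain, atol=1e-5))

# Gradient of an arbitrary loss with respect to the shift:
# zero in exact arithmetic, tiny float32 noise in practice.
weights = torch.randn(8, 2)
loss = (out_shifted * weights).sum()
loss.backward()
print(b.grad)
```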
The performance guide also mentions:

If a nn.Conv2d layer is directly followed by a nn.BatchNorm2d layer, then the bias in the convolution is not needed, instead use nn.Conv2d(..., bias=False, ....). Bias is not needed because in the first step BatchNorm subtracts the mean, which effectively cancels out the effect of bias.
This is also applicable to 1d and 3d convolutions as long as BatchNorm (or other normalization layer) normalizes on the same dimension as convolution’s bias.

However, that advice mainly targets performance: dropping the bias avoids launching an unnecessary add kernel.

CC @tom and @KFrank to correct me here in case my explanation is wrong. I’m also sure both can add interesting explanations from a more mathematical point of view. :wink:


As you might expect, @ptrblck’s explanation totally nails it; I’d probably strike the “almost” in “almost zero influence”. :slight_smile: The gradients should be 0 here, since the result does not depend on the parameter (the small nonzero values you see are floating-point rounding error). In other words, use bias=False for the linear/conv layer preceding batch norm.
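Following this reasoning, checking the gradients directly is more robust than comparing parameter values before and after optimizer.step(), because a value comparison also depends on the learning rate and on float rounding. Below is a minimal sketch; the helper name `find_dead_parameters`, the toy model, and the tolerance are illustrative assumptions. The model is cast to float64 so that the "zero" gradients sit far below any real gradient:

```python
import torch


def find_dead_parameters(model, tol=1e-8):
    """Hypothetical helper: names of parameters whose gradient is missing
    or whose largest absolute entry is below `tol`. Call after backward()."""
    dead = []
    for name, param in model.named_parameters():
        if param.grad is None or param.grad.abs().max().item() < tol:
            dead.append(name)
    return dead


torch.manual_seed(0)
# Toy model: Linear (with bias) -> BatchNorm1d -> output Linear.
model = torch.nn.Sequential(
    torch.nn.Linear(2, 2),     # "0": its bias is cancelled by the batch norm
    torch.nn.BatchNorm1d(2),   # "1"
    torch.nn.Linear(2, 1),     # "2": output layer, its bias matters
).double()
model.train()

inputs = torch.randn(64, 2, dtype=torch.float64)
targets = torch.randn(64, 1, dtype=torch.float64)
loss = torch.nn.functional.mse_loss(model(inputs), targets)
loss.backward()

print(find_dead_parameters(model))  # expected to report only '0.bias'
```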

If it is any consolation, it is #5 on Andrej Karpathy’s somewhat famous list of common NN training mistakes, so you are in good company.

Best regards

Thomas


@ptrblck @tom thank you very much for a quick and clear explanation.

When I use bias=False in the preceding Linear layer, I see the indicator that all parameters have been updated.
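For reference, the change described here amounts to something like the following version of the submodule from the original post (the class name `SubModule2WithBNNoBias` is made up for this sketch). The BatchNorm1d layer’s own learnable `bias` takes over the role of the removed Linear bias:

```python
import torch


class SubModule2WithBNNoBias(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        # No bias: BatchNorm1d subtracts the batch mean, which would cancel it anyway.
        self.fc = torch.nn.Linear(size, size, bias=False)
        # bn.bias supplies the learnable per-feature shift instead.
        self.bn = torch.nn.BatchNorm1d(size)

    def forward(self, inputs):
        return self.bn(self.fc(inputs))


block = SubModule2WithBNNoBias(2)
print([name for name, _ in block.named_parameters()])
# ['fc.weight', 'bn.weight', 'bn.bias']
print(block(torch.randn(4, 2)).shape)  # torch.Size([4, 2])
```

By the same argument, a constant shift that reaches a batch norm through further Linear layers is also cancelled. In MyModel this applies to fc1.bias, which feeds submodule2[0].fc and then submodule2[0].bn, so it may be worth removing that bias as well.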

I believe your response also explains what I saw in a test I ran after posting my question. Instead of doing only one backward pass with a parameter update, I ran multiple iterations, and I eventually saw the bias parameter reported as updated. Sometimes it took up to 200 iterations before an update was detected.

In other words, it took many iterations for the effect of the very small bias gradients to “accumulate” enough to have an effect on parameter updating.

Again, thank you for your insights and the link to the PyTorch performance guide.
