Performance highly degraded when eval() is activated in the test phase

@Valerio_Biscione Thank you!! This fixed my low accuracy in eval mode when using a smaller batch size.
In my case I didn’t have direct access to the model class, so I couldn’t initialize the batch norm layers with track_running_stats=False. As you rightly mentioned, the latest commit checks the batch norm stats to decide whether the layer is in training or eval mode, so I set the running mean and var buffers of the batch norm layers to None and it worked out perfectly.

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        # disable the running estimates so batch statistics are used in eval mode too
        m.track_running_stats = False
        m.running_mean = None
        m.running_var = None

model.eval()

I have encountered the same problem.
Simply put, the model seems to train well and the loss is as expected during training.
When testing, after setting model.eval(), the results look bad and the loss is high.
Using model.train(), or setting m.track_running_stats = False, really improves the results; however, if I evaluate the model with batch_size=1, the results are bad again.
Then I checked my code and found that the batch norm layers’ affine is set to False, which I think is probably causing the problem. I am now retraining my model with affine=True and will report the result.
BTW, if you encounter the same problem and your batch norm’s affine is False, I think that may be the reason.
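In case it helps, a minimal sketch of how to check which BatchNorm2d layers were created with affine=False (the Sequential here is just a stand-in for your own model):

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8, affine=False))  # stand-in model

for name, m in model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        print(name, "affine =", m.affine)  # False means no learnable weight/bias (gamma/beta)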

I have the same problem.
I tried to overfit by training on only one picture and validating on it at the same time. The training output is very good, while the validation output is bad. Then I set:
cudnn.deterministic = True
It worked. I speculate that the cuDNN benchmark algorithm causes the result of each forward pass to be slightly different.
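For reference, a minimal sketch of how I set these flags globally (using the standard torch.backends.cudnn switches):

import torch

# ask cuDNN to pick deterministic kernels instead of benchmarking for the fastest ones
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False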

Thanks for your answer! This works for me! It has to be set to False when creating the layer, i.e. nn.BatchNorm2d(…, track_running_stats=False).
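A minimal sketch of what that looks like (the feature count 64 is just an example):

import torch.nn as nn

bn = nn.BatchNorm2d(64, track_running_stats=False)
print(bn.running_mean, bn.running_var)  # both None: batch statistics are used in train and eval mode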

This resolves my issue. I just applied it at test time and it performs well. Thank you.

I am new to PyTorch. Could you tell me where exactly to put these lines?
Thank you!

@ptrblck I have read carefully your explanations regarding the behavior of BatchNorm2d in training and in eval mode. I am doing an experiment in which I try to overfit an HRNet model on only two images, so in the training loop I have a single batch of 2 images. In train mode BatchNorm2d acts as it is supposed to and the loss is in the expected range. When I run validation on the same batch with eval mode on, the loss explodes.

I would expect this behavior if the momentum were set to a value in (0, 1), but I am setting the momentum to 1 so that running_mean and running_var equal the statistics of the last batch, i.e. the only batch that I am training and validating on. Can you tell me what I am missing? I expected the train and validation losses to be nearly the same, and running_mean and running_var to be identical on each epoch, but this does not happen: the loss explodes and running_mean and running_var have different values on each epoch.

I don’t know, as I see the expected behavior:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, momentum=1.)

x = torch.randn(64, 3, 224, 224) * 25. + 57.
print(x.mean([0, 2, 3]))
# tensor([57.0125, 57.0183, 57.0110])
print(x.var([0, 2, 3], unbiased=False))
# tensor([624.4194, 626.0566, 624.7243])

print(bn.running_mean)
# tensor([0., 0., 0.])
print(bn.running_var)
# tensor([1., 1., 1.])

out = bn(x)

print(bn.running_mean)
# tensor([57.0125, 57.0183, 57.0110])
print(bn.running_var)
# tensor([624.4196, 626.0568, 624.7245])

import torch
import torch.nn as nn

# Define the CNN model with BatchNorm2D
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16, momentum=1)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(16 * 480 * 848, 2)
    
    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.flatten(x)
        x = self.fc(x)
        return x

# Generate two constant random tensors
tensor1 = torch.randn(2, 3, 480, 848)
tensor2 = torch.randn(2, 3, 480, 848)

# Instantiate the model
model = SimpleCNN()

# Set up optimizer and loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# L2 regularization strength
l2_lambda = 0.01

# Concatenate tensors and adjust labels
inputs = torch.cat((tensor1, tensor2), dim=0)
labels = torch.tensor([0, 1, 0, 1])  # Adjusted labels for the batch

# Set the number of epochs
num_epochs = 10

# Training loop for multiple epochs
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    
    # Forward pass
    outputs = model(inputs)
    
    # Compute cross-entropy loss
    loss = loss_fn(outputs, labels)
    
    # L2 regularization term
    l2_reg = torch.tensor(0.)
    for param in model.parameters():
        l2_reg += torch.norm(param, p=2)  # L2 norm
        
    # Compute the combined loss
    total_loss = loss + l2_lambda * l2_reg
    
    # Backward pass and optimization
    total_loss.backward()
    optimizer.step()

    # Print running_mean and running_var
    print(f"Epoch {epoch+1} - Running Mean:")
    print(model.bn1.running_mean)
    
    print(f"Epoch {epoch+1} - Running Variance:")
    print(model.bn1.running_var)

    with torch.no_grad():
        # Validation: switch to eval mode (note: total_loss printed here is still the training loss)
        model.eval()
        print(f"Epoch {epoch+1} Loss {total_loss}")

For example, I have generated the dummy code above; the outputs are:

Epoch 1 - Running Mean:
tensor([-0.0828,  0.0310, -0.1817, -0.1875,  0.0570, -0.1008,  0.1894,  0.1808,
         0.1177,  0.0498,  0.1221,  0.0795,  0.1806,  0.1176, -0.0843,  0.1756])
Epoch 1 - Running Variance:
tensor([0.3489, 0.2866, 0.3474, 0.3802, 0.2384, 0.3030, 0.2698, 0.4196, 0.3284,
        0.3617, 0.4089, 0.3673, 0.4466, 0.3379, 0.3950, 0.3943])
Epoch 1 Loss 0.6932005882263184
Epoch 2 - Running Mean:
tensor([-0.0828,  0.0310, -0.1817, -0.1875,  0.0569, -0.1008,  0.1894,  0.1807,
         0.1177,  0.0498,  0.1221,  0.0795,  0.1806,  0.1175, -0.0843,  0.1755])
Epoch 2 - Running Variance:
tensor([0.3489, 0.2866, 0.3474, 0.3802, 0.2384, 0.3030, 0.2698, 0.4196, 0.3284,
        0.3616, 0.4089, 0.3673, 0.4465, 0.3379, 0.3950, 0.3942])
Epoch 2 Loss 0.11885640770196915
Epoch 3 - Running Mean:
tensor([-0.0827,  0.0310, -0.1817, -0.1875,  0.0569, -0.1008,  0.1894,  0.1807,
         0.1177,  0.0497,  0.1220,  0.0795,  0.1806,  0.1175, -0.0843,  0.1755])
Epoch 3 - Running Variance:
tensor([0.3488, 0.2866, 0.3474, 0.3802, 0.2383, 0.3030, 0.2698, 0.4196, 0.3283,
        0.3616, 0.4089, 0.3673, 0.4465, 0.3379, 0.3950, 0.3942])
Epoch 3 Loss 0.11885036528110504
Epoch 4 - Running Mean:
tensor([-0.0827,  0.0310, -0.1816, -0.1874,  0.0569, -0.1007,  0.1893,  0.1806,
         0.1177,  0.0497,  0.1220,  0.0795,  0.1805,  0.1175, -0.0842,  0.1755])
Epoch 4 - Running Variance:
tensor([0.3488, 0.2866, 0.3473, 0.3801, 0.2383, 0.3030, 0.2698, 0.4195, 0.3283,
        0.3616, 0.4088, 0.3673, 0.4465, 0.3378, 0.3949, 0.3942])
Epoch 4 Loss 0.11884436011314392
Epoch 5 - Running Mean:
tensor([-0.0827,  0.0310, -0.1816, -0.1874,  0.0569, -0.1007,  0.1893,  0.1806,
         0.1176,  0.0497,  0.1220,  0.0795,  0.1805,  0.1175, -0.0842,  0.1754])
Epoch 5 - Running Variance:
tensor([0.3488, 0.2865, 0.3473, 0.3801, 0.2383, 0.3029, 0.2698, 0.4195, 0.3283,
        0.3615, 0.4088, 0.3672, 0.4464, 0.3378, 0.3949, 0.3941])
Epoch 5 Loss 0.11883837729692459
Epoch 6 - Running Mean:
tensor([-0.0827,  0.0310, -0.1816, -0.1873,  0.0569, -0.1007,  0.1893,  0.1806,
         0.1176,  0.0497,  0.1220,  0.0795,  0.1805,  0.1174, -0.0842,  0.1754])
Epoch 6 - Running Variance:
tensor([0.3487, 0.2865, 0.3473, 0.3801, 0.2383, 0.3029, 0.2697, 0.4195, 0.3283,
        0.3615, 0.4088, 0.3672, 0.4464, 0.3378, 0.3949, 0.3941])
Epoch 6 Loss 0.11883237212896347
Epoch 7 - Running Mean:
tensor([-0.0827,  0.0310, -0.1815, -0.1873,  0.0569, -0.1007,  0.1892,  0.1805,
         0.1176,  0.0497,  0.1220,  0.0794,  0.1804,  0.1174, -0.0842,  0.1754])
Epoch 7 - Running Variance:
tensor([0.3487, 0.2865, 0.3473, 0.3800, 0.2383, 0.3029, 0.2697, 0.4194, 0.3282,
        0.3615, 0.4087, 0.3672, 0.4464, 0.3378, 0.3948, 0.3941])
Epoch 7 Loss 0.11882635205984116
Epoch 8 - Running Mean:
tensor([-0.0827,  0.0310, -0.1815, -0.1873,  0.0569, -0.1007,  0.1892,  0.1805,
         0.1176,  0.0497,  0.1219,  0.0794,  0.1804,  0.1174, -0.0842,  0.1753])
Epoch 8 - Running Variance:
tensor([0.3487, 0.2865, 0.3472, 0.3800, 0.2382, 0.3029, 0.2697, 0.4194, 0.3282,
        0.3615, 0.4087, 0.3671, 0.4463, 0.3377, 0.3948, 0.3940])
Epoch 8 Loss 0.11882033944129944
Epoch 9 - Running Mean:
tensor([-0.0826,  0.0310, -0.1815, -0.1872,  0.0569, -0.1006,  0.1892,  0.1805,
         0.1176,  0.0497,  0.1219,  0.0794,  0.1804,  0.1174, -0.0842,  0.1753])
Epoch 9 - Running Variance:
tensor([0.3487, 0.2864, 0.3472, 0.3800, 0.2382, 0.3028, 0.2697, 0.4193, 0.3282,
        0.3614, 0.4087, 0.3671, 0.4463, 0.3377, 0.3948, 0.3940])
Epoch 9 Loss 0.11881443113088608
Epoch 10 - Running Mean:
tensor([-0.0826,  0.0310, -0.1814, -0.1872,  0.0569, -0.1006,  0.1891,  0.1804,
         0.1175,  0.0497,  0.1219,  0.0794,  0.1803,  0.1174, -0.0841,  0.1753])
Epoch 10 - Running Variance:
tensor([0.3486, 0.2864, 0.3472, 0.3799, 0.2382, 0.3028, 0.2696, 0.4193, 0.3282,
        0.3614, 0.4086, 0.3671, 0.4462, 0.3377, 0.3947, 0.3940])
Epoch 10 Loss 0.11881034821271896

You can see that the values change slightly. Since this is a small network, it might not have much influence here, but in a huge one, for example my HRNet, I think those small changes can accumulate and make the loss explode.

I doubt that small changes in the range of 1e-4 are responsible for your loss explosion, as this already comes close to the expected absolute error for float32 when a different order of operations is used, as seen in this small example:

import torch

x = torch.randn(100, 100, 100)
s1 = x.sum()
s2 = x.sum(0).sum(0).sum(0)
print((s1 - s2).abs())
# tensor(9.1553e-05)

I have printed the running_mean of the first BatchNorm layer of the HRNet model; here are the results for the first few epochs:

tensor([-2.8645e-04, -7.3611e-04,  4.6899e-04, -5.0444e-04, -1.9839e-04,
        -4.0576e-04, -1.3211e-03,  3.9516e-04, -1.0368e-03,  3.0769e-04,
        -1.1822e-03,  2.8824e-05,  2.0874e-04,  5.2676e-04, -4.5067e-04,
        -2.3251e-04,  5.6170e-04, -5.9286e-04,  5.8100e-04, -7.5048e-04,
        -6.2080e-05,  4.6286e-04, -7.8213e-04,  2.1536e-04,  1.0677e-03,
        -1.3129e-04, -3.4901e-04, -1.4472e-04,  4.5749e-04, -2.2343e-04,
        -4.3963e-04, -1.0693e-04,  4.3047e-04,  3.1450e-04, -2.0459e-04,
         5.7683e-04,  2.6893e-04, -1.3746e-03, -6.7454e-05, -3.2940e-04,
         6.1982e-05, -1.7527e-04, -9.9344e-05,  9.0883e-04,  1.1961e-05,
         2.4231e-04,  2.4875e-04,  1.3527e-03, -7.4792e-04, -1.3124e-04,
        -2.1075e-04, -6.0424e-05, -9.3601e-04,  1.0437e-04,  1.1708e-03,
         2.5339e-04,  1.0223e-03, -4.9711e-04,  4.3105e-04, -1.4752e-05,
        -1.5058e-03,  1.4785e-04, -2.3201e-04, -1.6655e-04], device='cuda:0')
tensor([ 0.0275, -0.0285, -0.0038, -0.0157, -0.0280, -0.0282, -0.0189, -0.0248,
        -0.0050,  0.0277,  0.0045,  0.0274, -0.0276, -0.0272, -0.0018, -0.0090,
        -0.0272,  0.0155, -0.0253, -0.0285,  0.0277,  0.0282, -0.0285, -0.0176,
        -0.0121, -0.0268,  0.0112, -0.0213, -0.0246, -0.0222,  0.0164, -0.0279,
         0.0220,  0.0195,  0.0091,  0.0116,  0.0280, -0.0164, -0.0217, -0.0281,
        -0.0254, -0.0221, -0.0275, -0.0211, -0.0270,  0.0020,  0.0280, -0.0169,
         0.0165,  0.0042,  0.0199, -0.0248, -0.0052,  0.0207,  0.0163, -0.0275,
         0.0223,  0.0092,  0.0047, -0.0278, -0.0018,  0.0279,  0.0261, -0.0243],
       device='cuda:0')
tensor([ 0.0459, -0.0457,  0.0039, -0.0376, -0.0465, -0.0484, -0.0242, -0.0418,
        -0.0171,  0.0464, -0.0020,  0.0480, -0.0457, -0.0460,  0.0195, -0.0320,
        -0.0462,  0.0242, -0.0414, -0.0440,  0.0462,  0.0485, -0.0468, -0.0420,
        -0.0290, -0.0443,  0.0191, -0.0328, -0.0458, -0.0342,  0.0249, -0.0462,
         0.0348,  0.0361,  0.0106,  0.0217,  0.0469, -0.0186, -0.0327, -0.0469,
        -0.0417, -0.0410, -0.0471, -0.0298, -0.0450, -0.0100,  0.0466, -0.0394,
         0.0216,  0.0242,  0.0329, -0.0276, -0.0037,  0.0428,  0.0205, -0.0465,
         0.0374,  0.0054,  0.0052, -0.0479,  0.0083,  0.0467,  0.0439, -0.0387],
       device='cuda:0')
tensor([ 0.0600, -0.0585, -0.0001, -0.0544, -0.0606, -0.0642, -0.0260, -0.0550,
        -0.0360,  0.0609, -0.0008,  0.0635, -0.0602, -0.0604,  0.0365, -0.0497,
        -0.0607,  0.0292, -0.0529, -0.0548,  0.0604,  0.0639, -0.0607, -0.0597,
        -0.0434, -0.0577,  0.0260, -0.0475, -0.0629, -0.0371,  0.0320, -0.0603,
         0.0446,  0.0486,  0.0141,  0.0415,  0.0615, -0.0149, -0.0394, -0.0613,
        -0.0536, -0.0563, -0.0623, -0.0346, -0.0589, -0.0193,  0.0608, -0.0569,
         0.0261,  0.0400,  0.0431, -0.0226,  0.0036,  0.0610,  0.0387, -0.0612,
         0.0497,  0.0056,  0.0088, -0.0636,  0.0187,  0.0610,  0.0574, -0.0514],
       device='cuda:0')

Once again, the momentum here is set to 1, so this behavior is very odd. I have also noticed that in the low-resolution branches the explosion of activations is even more extreme. I tried an experiment in which all instances of BatchNorm were replaced with InstanceNorm, and training and validation then behaved normally, with no exploding loss on validation.
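For reference, one possible sketch of such a replacement (the recursive helper and the affine=True choice are illustrative assumptions, not necessarily the exact code I used):

import torch.nn as nn

def replace_bn_with_in(module):
    # recursively swap every BatchNorm2d for an InstanceNorm2d of the same width
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.InstanceNorm2d(child.num_features, affine=True))
        else:
            replace_bn_with_in(child)

# usage: replace_bn_with_in(model)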

Could you explain why it’s odd? Did you check the stats of the input activations for these batchnorm layers, and were you seeing a mismatch?

Yes, sure. If you train a model on a single batch and use the same batch for validation, then with the momentum set to 1 batch norm should behave as if track_running_stats were set to False: the running mean and var buffers should not change, because they are updated as running_stat = (1 - momentum) * running_stat + momentum * batch_stat (the formula from the note in the BatchNorm2d documentation), which with momentum=1 simply copies the statistics of the last batch.
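A minimal sketch checking that claim (shapes are arbitrary):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, momentum=1.)
x = torch.randn(8, 3, 16, 16)

bn.train()
for _ in range(3):  # repeated forward passes over the same batch
    bn(x)

# with momentum=1 the running stats equal the stats of the last (here: only) batch
print(torch.allclose(bn.running_mean, x.mean([0, 2, 3]), atol=1e-5))
print(torch.allclose(bn.running_var, x.var([0, 2, 3], unbiased=True), atol=1e-4))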


Yes, which is also the behavior visible in my code snippet.
However, in your previous post you printed stats from seemingly random input batches, so where does the expectation come from that these should be equal?
You are still training your model, aren't you?

In this experiment I wanted to test that the pipeline I created around the HRNet model works. I previously trained the model on a large dataset with a little over 100,000 images and this phenomenon did not occur. I couldn't pinpoint why, when training on a single batch of size 2, the model in eval mode starts from an exploding loss; by debugging it I arrived at the conclusion that it is due to BatchNorm layers behaving differently in train and eval mode. If I set track_running_stats to False the loss behaves normally. My problem is that I can't see why, when I set the momentum to 1 while overfitting on one batch, the loss does not behave as it should: the running_mean/var changes and the activations explode.

I managed to understand why this happens. With the learnable parameters on, the model does a backprop step; the running_mean and running_var stay the same (they were recorded from the forward pass before the update), but the updated weights produce activations with different statistics, so the distributions no longer match and the error accumulates at each stage of the model. The solution is to set track_running_stats to False, because setting the momentum to 1 with a single batch does not have the same behavior. Thank you for your help.
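For completeness, a minimal sketch of the effect (a single conv + batch norm with an exaggerated learning rate so the shift is visible; the layer sizes are arbitrary):

import torch
import torch.nn as nn

torch.manual_seed(0)

conv = nn.Conv2d(3, 4, 3, padding=1)
bn = nn.BatchNorm2d(4, momentum=1.)          # momentum=1: running stats = stats of the last batch
opt = torch.optim.SGD(list(conv.parameters()) + list(bn.parameters()), lr=1.0)

x = torch.randn(2, 3, 8, 8)

# one training step: running stats are recorded from the conv output *before* the weight update
loss = torch.relu(bn(conv(x))).sum()
loss.backward()
opt.step()

# after the step, the same batch produces activations with different statistics,
# so the stored running_mean no longer matches what eval mode will normalize with
with torch.no_grad():
    new_act = conv(x)
print(bn.running_mean)            # stats of the pre-update activations
print(new_act.mean([0, 2, 3]))    # stats of the post-update activations: different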
