Resnet101 sensitive to previous evaluations in train mode?

I’m retraining resnet101 for an image classification task, and I observe that my model behaves differently in eval mode if it has previously been run in training mode. Here is a code example:

```python
from torchvision import models, transforms
import torch
from PIL import Image

# Standard ImageNet preprocessing
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])

img = Image.open("C:/image.jpg").convert("RGB")  # ensure 3 channels
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0).cuda()  # add batch dimension: [1, 3, 224, 224]
resnet = models.resnet101(pretrained=True)
resnet.cuda()
resnet.eval()
print(max(resnet(batch_t)[0]))
# tensor(13.7692, device='cuda:0', grad_fn=<SelectBackward>)
# Performing a model evaluation in training mode:
resnet.train()
dummy = resnet(batch_t)
# Switching back to eval mode, and the model output is different:
resnet.eval()
print(max(resnet(batch_t)[0]))
# tensor(13.2980, device='cuda:0', grad_fn=<SelectBackward>)
```

Apparently, running the model in training mode (without updating any weights) changes the behavior of the model in eval mode.
I assume I’m missing something, but why would the statements `resnet.train()` and `dummy = resnet(batch_t)` make any changes to the model?

The forward pass in training mode updates the running statistics of all batch norm layers, so the next output in `model.eval()` will change.
Note that `model.train()` and `model.eval()` exist precisely to switch this behavior: during training, the batch statistics are used to normalize the data and the running stats are updated, while during evaluation the running stats are used to normalize the input data. (Unrelated to this issue, but dropout is also disabled during `model.eval()`.)
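A minimal sketch of this effect, assuming a single standalone `nn.BatchNorm2d` layer as a stand-in for the batch norm layers inside resnet101 (the shapes and toy data are made up for illustration). It also shows that wrapping the train-mode pass in `torch.no_grad()` does not prevent the update; only the module's mode matters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in for one of resnet101's batch norm layers (hypothetical toy setup).
# A fresh layer starts with running_mean = 0 and running_var = 1.
bn = nn.BatchNorm2d(num_features=3)
x = torch.randn(8, 3, 4, 4) * 2.0 + 5.0  # toy data: mean ~5, std ~2

bn.eval()
with torch.no_grad():
    out_before = bn(x)  # normalized with the running stats
mean_before = bn.running_mean.clone()

bn.train()
with torch.no_grad():   # no_grad does NOT stop the running-stat update
    bn(x)               # uses batch stats and updates running stats in place

bn.eval()
with torch.no_grad():
    out_after = bn(x)   # same input, but the running stats have moved
mean_after = bn.running_mean.clone()

with torch.no_grad():
    bn(x)               # another eval-mode pass: reads but never writes the stats

print(mean_before)                               # tensor([0., 0., 0.])
print(torch.allclose(out_before, out_after))     # False
print(torch.equal(bn.running_mean, mean_after))  # True
```

With the default momentum of 0.1, each entry of `mean_after` is 0.1 times the per-channel batch mean, since the running mean started at 0.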


I thought I understood the answer at the time, but now I’m less sure.
Over which period are the batch norm running statistics collected?
Are they reset or discarded each time I switch the model between modes?
That would work fine if you run through all your training images in one sequence (which is common, I guess), but I use a sampling procedure that randomly draws images from randomly chosen classes, in order to imitate a balanced training set. During training, I frequently switch back and forth between training and eval mode to monitor the progress.
I guess this is a bad idea if it resets the running statistics every time, since they really should be calibrated on the entire data set?

During `model.train()`, every forward pass uses the current batch statistics to update the running stats as an exponential moving average controlled by `momentum`.
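To illustrate that the running stats accumulate over all train-mode forward passes (older batches being weighted down exponentially), here is a small sketch with a toy `nn.BatchNorm1d` layer and made-up data whose mean is 3; the 100 steps and batch size of 32 are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=1)  # default momentum=0.1
bn.train()

with torch.no_grad():
    for step in range(100):
        batch = torch.randn(32, 1) + 3.0  # toy data with mean ~3
        bn(batch)  # running = 0.9 * running + 0.1 * batch_stat

print(bn.running_mean)         # close to 3: the initial 0 has been "forgotten"
print(bn.num_batches_tracked)  # tensor(100)
```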

No, switching modes does not reset them. You can check them via `bn.running_mean` and `bn.running_var`.
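A sketch of such a check, assuming a hypothetical small CNN in place of resnet101 (torchvision's ResNets also use `nn.BatchNorm2d`, so the same loop over `named_modules()` applies). `snapshot_bn_stats` is a made-up helper for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small CNN standing in for resnet101.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)


def snapshot_bn_stats(m):
    """Return copies of (running_mean, running_var) for every BatchNorm2d in m."""
    return {
        name: (mod.running_mean.clone(), mod.running_var.clone())
        for name, mod in m.named_modules()
        if isinstance(mod, nn.BatchNorm2d)
    }


x = torch.randn(4, 3, 16, 16)
model.train()
with torch.no_grad():
    model(x)  # one train-mode pass moves the stats away from their defaults (0 and 1)
before = snapshot_bn_stats(model)

# Toggle modes repeatedly, as when monitoring progress during training.
for _ in range(5):
    model.eval()
    model.train()
model.eval()
with torch.no_grad():
    model(x)  # eval-mode passes read the stats but never write them
after = snapshot_bn_stats(model)

for name, (mean, var) in before.items():
    unchanged = torch.equal(mean, after[name][0]) and torch.equal(var, after[name][1])
    print(name, unchanged)  # prints: 1 True
```

On the real model, calling `snapshot_bn_stats(resnet)` before and after a round of `train()`/`eval()` switches should show the same thing: the stats only move during train-mode forward passes.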


Thanks for always answering quickly and accurately, ptrblck!
I’ve tried to dig into the issue some more, and in another post I found this code where you implement this: “https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py”.
If momentum is in use, you set `exponential_average_factor = self.momentum`, and your update of the running mean looks like this:
```python
self.running_mean = exponential_average_factor * mean \
    + (1 - exponential_average_factor) * self.running_mean
```
If momentum is close to 1 (typically 0.9 in my application), doesn’t this calculation place most of the weight on the latest observation and almost none on the history?
Shouldn’t it be the other way around, i.e. `exponential_average_factor = 1 - self.momentum`, for this calculation to accumulate a stable Ornstein-Uhlenbeck type average?
Or should momentum be close to 0?
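To make the question concrete: unrolling the update `running = factor * x_t + (1 - factor) * running` shows that an observation k batches back gets weight `factor * (1 - factor)**k`. A quick sketch comparing factor 0.9 and 0.1; `ema_weights` is a hypothetical helper, not part of PyTorch.

```python
def ema_weights(factor, n):
    """Weight that running = factor * x_t + (1 - factor) * running assigns
    to the observation k steps back, for k = 0 (latest) .. n - 1."""
    return [factor * (1 - factor) ** k for k in range(n)]


for factor in (0.9, 0.1):
    print(factor, [round(w, 4) for w in ema_weights(factor, 5)])
# 0.9 [0.9, 0.09, 0.009, 0.0009, 0.0001]
# 0.1 [0.1, 0.09, 0.081, 0.0729, 0.0656]
```

With a factor of 0.9 the average effectively remembers about one batch, while 0.1 spreads the weight over roughly the last ten or more.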

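It depends on how you define the momentum. Note that PyTorch uses a probably unexpected default value of 0.1: in its definition, momentum is the weight given to the newest batch statistic, so small values (close to 0) keep a long history.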
I’ve implemented the manual batchnorm approach to stick to the PyTorch definition. From the docs:

This momentum argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is $\hat{x}_{\text{new}} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t$, where $\hat{x}$ is the estimated statistic and $x_t$ is the new observed value.
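A quick sketch checking this rule against `nn.BatchNorm1d` for a single training step, using the default momentum of 0.1 and made-up starting values for the running stats. One detail worth noting: the running variance is updated with the unbiased batch variance, while the normalization itself uses the biased one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
momentum = 0.1  # PyTorch default
bn = nn.BatchNorm1d(num_features=4, momentum=momentum)
bn.train()

# Give the running stats a non-trivial (made-up) starting point.
with torch.no_grad():
    bn.running_mean.fill_(1.0)
    bn.running_var.fill_(2.0)
old_mean = bn.running_mean.clone()
old_var = bn.running_var.clone()

x = torch.randn(16, 4) * 3.0 + 2.0
with torch.no_grad():
    bn(x)  # one train-mode step

# x_hat_new = (1 - momentum) * x_hat + momentum * x_t
expected_mean = (1 - momentum) * old_mean + momentum * x.mean(dim=0)
expected_var = (1 - momentum) * old_var + momentum * x.var(dim=0)  # unbiased by default

print(torch.allclose(bn.running_mean, expected_mean, atol=1e-6))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))    # True
```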
