Model.eval() gives strange behaviour

Hello,
I am training a UNet model on Colab for image segmentation. The problem is that at the end of each epoch, when I run the evaluation step (switching the net to model.eval()), the results are really strange: the performance degrades a lot and it seems that the net is not learning anything. I plotted the predicted masks and the predicted values are really bad, with a confusing range.
There are two other behaviors that are very strange:

  • if I set the net to model.train() and predict the mask for the same image (a validation image), I get much better results
  • I trained the whole net, using model.eval() in the evaluation step. At the last stage, I performed the evaluation on the test dataset in order to analyze the real performance. In this last step the performance is really good, comparable with what I reached when I used model.train(). I want to point out that for the test prediction I use model.eval(), exactly as in the validation step.

I know that with model.eval() the BatchNormalization and Dropout layers behave differently, and in the UNet architecture there are some BatchNormalization layers. I tried to set the track_running_stats parameter to False, but nothing changes.
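For reference, this is roughly what I tried (disable_running_stats is just my own helper that loops over the modules, it is not part of the UNet code):

import torch.nn as nn

def disable_running_stats(model):
    # force every BatchNorm2d layer to use batch statistics even in eval():
    # with running_mean/running_var set to None, eval() cannot fall back on
    # the running estimates
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.track_running_stats = False
            m.running_mean = None
            m.running_var = None

disable_running_stats(net)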

Currently I am using a batch size of 4 or 8, which is quite small, but I cannot use larger values due to GPU limitations.

Here is a snippet of my code, from inside the for loop over the epochs:

net.train()
epoch_loss = 0
for batch in train_loader:
    imgs = batch['image']
    true_masks = batch['mask']
    imgs = imgs.to(device=device, dtype=torch.float32)
    # masks are long for multi-class cross-entropy, float for a single class
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    true_masks = true_masks.to(device=device, dtype=mask_type)

    # forward pass and loss
    masks_pred = net(imgs)
    loss = criterion(masks_pred, true_masks.squeeze(1))
    epoch_loss += loss.item()

    # backward pass with gradient value clipping
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_value_(net.parameters(), 0.1)
    optimizer.step()

# Validation
val_ce, val_iou = eval_net(net, model_arch, val_loader, device)

The eval_net function is defined as follows (the IoU computation is not shown):

def eval_net(net, model_arch, loader, device):
    net.eval()
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    n_val = len(loader)
    tot_ce = 0
    for batch in loader:
        imgs, true_masks = batch['image'], batch['mask']
        imgs = imgs.to(device=device, dtype=torch.float32)
        true_masks = true_masks.to(device=device, dtype=mask_type)

        # no gradients needed during validation
        with torch.no_grad():
            mask_pred = net(imgs)

        # F is torch.nn.functional
        tot_ce += F.cross_entropy(mask_pred, true_masks.squeeze(1)).item()

    # (the IoU computation and the return of the two values are omitted here)

Any idea on the possible reason for this problem?

I don’t think this batch size is suitable for the batchnorm layers. It may decrease the performance instead of increasing it. You could remove them, or use other normalization layers such as instance norm or layer norm.
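Something like this, for example (the double_conv block below is just a generic UNet block, not your exact code; it is only meant to show where the swap goes):

import torch.nn as nn

def double_conv(in_ch, out_ch):
    # InstanceNorm2d does not use batch statistics, so train() and eval()
    # behave the same regardless of the batch size
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.InstanceNorm2d(out_ch, affine=True),  # instead of nn.BatchNorm2d(out_ch)
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.InstanceNorm2d(out_ch, affine=True),
        nn.ReLU(inplace=True),
    )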

Ok thanks, what is the minimum batch size suitable for BN, in your opinion?

I usually try to use as big a batch as possible: 512, 1024, 2048. It needs a lot of GPU memory, though. I have the same problem as yours when I work on my laptop; I just remove the batch norm to avoid it.

Ok, I tried (in another environment) a batch size of 32, which seems large enough to guarantee robust estimates of the means and variances. The problem occurs even in this case, which seems a bit strange… so I cannot use BN unless I have a big cluster.
Which other kind of normalization do you suggest?

Did you test without normalization, to see whether it works or not? Sometimes it works quite well without BN.
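If you want to try that quickly without rewriting the model, something like this should work (strip_batchnorm is just an illustrative helper, not an existing API):

import torch.nn as nn

def strip_batchnorm(model):
    # recursively replace every BatchNorm2d with a no-op layer
    for name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(model, name, nn.Identity())
        else:
            strip_batchnorm(child)

strip_batchnorm(net)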

I also found that this problem happens if you do not carefully normalize your dataset. If the mean and std differ a lot between mini-batches, the model.train() performance will be much better than model.eval() because of BatchNormalization. The way to solve it is to normalize the whole dataset to mean = 0, std = 1.
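A rough sketch of what I mean, assuming the loader yields un-normalized 3-channel image tensors of shape (B, C, H, W):

import torch

channel_sum = torch.zeros(3)
channel_sq_sum = torch.zeros(3)
n_pixels = 0
for batch in train_loader:
    imgs = batch['image'].float()
    n_pixels += imgs.numel() / imgs.shape[1]           # pixels per channel
    channel_sum += imgs.sum(dim=[0, 2, 3])
    channel_sq_sum += (imgs ** 2).sum(dim=[0, 2, 3])

mean = channel_sum / n_pixels
std = (channel_sq_sum / n_pixels - mean ** 2).sqrt()   # E[x^2] - E[x]^2
# then use these in torchvision.transforms.Normalize(mean.tolist(), std.tolist())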

I have also met this problem. Can anyone give me some suggestions? Thanks!

I have a specialized case where I load an ONNX model with weights and use the onnx2torch library to convert the network to PyTorch format before training and validation. I noticed the same issue with model.eval(), and the only thing that fixed it for me was to upgrade the “opset version” of the ONNX network to version 13 before the onnx2torch conversion.

I started with my network in Matlab as a .mat file and used exportONNXNetwork(net, filename, OpsetVersion = 13) to accomplish this. Then I loaded the ONNX model into PyTorch and converted it to the PyTorch format. Now model.eval() does not hinder performance in validation at all.
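I did the opset change on the Matlab side, but roughly the same thing can be done in Python with ONNX's version converter before running onnx2torch (the file name below is a placeholder):

import onnx
from onnx import version_converter
from onnx2torch import convert

onnx_model = onnx.load("my_network.onnx")                        # placeholder path
onnx_model = version_converter.convert_version(onnx_model, 13)   # upgrade the opset
net = convert(onnx_model)                                        # torch.nn.Module
net.eval()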

This was a frustrating problem to have so I hope this solution is helpful beyond my particular application! Maybe there is some analog to OpsetVersion even if you are not dealing with ONNX like I am.