Most related reports I can find online describe model performance degrading when switching from “model.train()” to “model.eval()”. I have exactly the opposite problem: model performance degrades when switching from “model.eval()” to “model.train()”.
I am training a UNet for a segmentation task on 3 GPUs using PyTorch DDP. The only modules that behave differently in training and evaluation mode are the “SyncBatchNorm” layers in each convolution block. I ran into this problem when trying to resume training from the saved model and optimizer state dicts. The loss is very high in the first few steps, though it quickly goes down afterwards. This happens even when I set the initial learning rate to a very low value.
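Since resuming goes through the saved state dicts, one thing that can be ruled out first is whether the batch-norm buffers (running_mean, running_var, num_batches_tracked) survive the save/load round trip. The following is only a minimal sketch under stated assumptions: make_block is a hypothetical stand-in for one convolution block, and nn.BatchNorm2d replaces SyncBatchNorm so that it runs on CPU (both register the same buffer names). Note that a state dict saved from the DDP wrapper has keys prefixed with "module.", so loading it into an unwrapped model with strict=False would silently skip every key.

```python
import io

import torch
import torch.nn as nn


def make_block():
    # Hypothetical stand-in for one UNet convolution block.
    # nn.BatchNorm2d is used so the sketch runs on CPU.
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())


torch.manual_seed(0)
trained = make_block()
trained.train()
with torch.no_grad():
    # One forward pass in train mode moves the running statistics.
    trained(torch.randn(16, 3, 32, 32) * 2.0 + 1.0)

# Save and reload the way a resume would (in memory here, for brevity).
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)
resumed = make_block()
result = resumed.load_state_dict(torch.load(buffer, weights_only=True), strict=True)

# Every batch-norm buffer should come back identical.
resumed_buffers = dict(resumed.named_buffers())
bn_buffers_match = all(
    torch.equal(buf, resumed_buffers[name]) for name, buf in trained.named_buffers()
)
print(sorted(name for name, _ in trained.named_buffers()))
print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)
print("BN buffers match:", bn_buffers_match)
```

If the buffers match and no keys are missing or unexpected, the checkpoint itself is probably not what loses the statistics.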
I isolated the problem with the following code:
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
    model.eval()
    mask_pred_batch = model(image_batch)
    loss = compute_loss(mask_pred_batch, mask_gt_batch)  # calculate loss
    if local_rank == 0:
        print(loss.item())
    model.train()
    mask_pred_batch = model(image_batch)
    loss = compute_loss(mask_pred_batch, mask_gt_batch)  # calculate loss
    if local_rank == 0:
        print(loss.item())
And the printed loss values are:
0.10580593347549438
17.66183853149414
Essentially, the same batch is passed through the model twice, once in evaluation mode and once in training mode. The batch comes from the same training data used in the previous training session, so the first loss value is very low. However, just switching the model to training mode dramatically increases the loss, to the level of an untrained model.
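Because evaluation mode normalizes with the stored running statistics while training mode normalizes with the statistics of the current batch, one possible diagnostic is to measure how far apart the two are at every batch-norm layer. The helper below is a hedged sketch, not a definitive tool: bn_stat_gaps is a hypothetical name, the tiny model is only there to make the example runnable, and under DDP each rank would see only its local batch rather than the synchronized statistics.

```python
import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm  # base class of BatchNorm2d and SyncBatchNorm


def bn_stat_gaps(model, batch):
    """Return {layer name: (max |batch mean - running mean|, max batch var / running var)}."""
    gaps = {}
    hooks = []

    def make_hook(name):
        def hook(module, inputs, output):
            x = inputs[0].detach().float()
            dims = [d for d in range(x.dim()) if d != 1]  # reduce over all but the channel dim
            mean = x.mean(dim=dims)
            var = x.var(dim=dims, unbiased=False)
            gaps[name] = (
                (mean - module.running_mean.float()).abs().max().item(),
                (var / module.running_var.float()).max().item(),
            )
        return hook

    for name, module in model.named_modules():
        if isinstance(module, _BatchNorm) and module.track_running_stats:
            hooks.append(module.register_forward_hook(make_hook(name)))

    was_training = model.training
    model.eval()  # eval mode, so this pass does not modify the running statistics
    try:
        with torch.no_grad():
            model(batch)
    finally:
        for handle in hooks:
            handle.remove()
        model.train(was_training)
    return gaps


# Toy usage: a freshly initialized layer has running mean 0 and running var 1.
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.BatchNorm2d(4))
toy_batch = torch.randn(8, 3, 16, 16) * 3.0 + 5.0
toy_gaps = bn_stat_gaps(toy_model, toy_batch)
for layer_name, (mean_gap, var_ratio) in toy_gaps.items():
    print(layer_name, mean_gap, var_ratio)
```

Layers with a large mean gap, or a variance ratio far from 1, are the ones where the two modes would produce very different activations for the same input.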
The image size is 512x512 and the batch size per GPU is 16 (48 in total, since I use SyncBatchNorm on 3 GPUs). The same batch size is used before and after resuming. The BatchNorm2d momentum is 0.1 (the default value). The PyTorch version is 2.5.1.
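For reference, with momentum 0.1 each training-mode forward pass moves the running statistics by 10% of the way toward the current batch statistics: running = (1 - momentum) * running + momentum * batch, where the running variance is updated with the unbiased batch variance. The sketch below checks this on a single nn.BatchNorm2d layer on CPU, used here as a stand-in for SyncBatchNorm; the input tensor is made up purely for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
momentum = 0.1
bn = nn.BatchNorm2d(4, momentum=momentum)  # running_mean starts at 0, running_var at 1
x = torch.randn(16, 4, 8, 8) * 3.0 + 5.0   # illustrative input, far from (0, 1)

# Per-channel statistics of this batch.
batch_mean = x.mean(dim=(0, 2, 3))
batch_var_unbiased = x.var(dim=(0, 2, 3), unbiased=True)

# A train-mode forward pass updates the running statistics, even under no_grad.
bn.train()
with torch.no_grad():
    bn(x)

expected_mean = (1 - momentum) * torch.zeros(4) + momentum * batch_mean
expected_var = (1 - momentum) * torch.ones(4) + momentum * batch_var_unbiased
print(bn.running_mean, expected_mean)
print(bn.running_var, expected_var)
```

The same rule means that after n training steps the statistics stored before those steps still carry a weight of (1 - momentum) ** n.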
I cannot understand why this happens. I hope someone can give me some hints to help solve the problem.