Hello,
I am training a UNet model on Colab for image segmentation. The problem is that at the end of each epoch, when I run the evaluation step (switching the net to evaluation mode with model.eval()), the results are very strange. Performance degrades dramatically, as if the net had not learned anything. When I plot the predicted masks, the predicted values are very poor and fall in a confusing range.
There are two other behaviors that are very strange:
- if I set the net back to model.train() and predict the mask for the same (validation) image, I get much better results
- I trained the whole net while still using model.eval() in the evaluation step. At the end, I evaluated on the test dataset to measure the real performance. In this final step the performance is really good, comparable to what I reached when using model.train(). I want to point out that for the test prediction I use model.eval() exactly as in the validation step.
I know that with model.eval() the BatchNorm and Dropout layers behave differently, and the UNet architecture contains several BatchNorm layers. I tried setting the track_running_stats parameter to False, but nothing changed.
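For reference, this is the kind of diagnostic I have been experimenting with: keeping only the BatchNorm layers in batch-statistics (train) mode while the rest of the net stays in eval mode. This is a minimal sketch, and the `bn_to_train` helper is just an illustrative name, not part of my actual code:

```python
import torch
import torch.nn as nn

def bn_to_train(module: nn.Module) -> None:
    # Illustrative helper: switch only BatchNorm layers back to train mode,
    # so they normalize with the current batch statistics instead of the
    # (possibly stale) running_mean / running_var buffers.
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        module.train()

# Usage sketch on a toy net standing in for the UNet:
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
net.eval()               # everything in eval mode (Dropout off, BN uses running stats)
net.apply(bn_to_train)   # BN layers alone go back to batch statistics
```

If validation results become good again with this applied, that would point to the BN running statistics (rather than the weights) being the problem.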
Currently I am using a batch size of 4 or 8, which is quite small, but I cannot use larger values due to GPU memory limits.
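One mitigation I am considering (not something I have fully tested yet) is swapping BatchNorm for GroupNorm, whose statistics are computed per sample and so do not depend on the batch size. A minimal sketch, where `conv_gn_block` is just an illustrative name:

```python
import torch
import torch.nn as nn

def conv_gn_block(in_ch: int, out_ch: int, groups: int = 8) -> nn.Sequential:
    # Conv block using GroupNorm instead of BatchNorm: normalization is done
    # over channel groups within each sample, so train and eval behave the same.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.GroupNorm(groups, out_ch),
        nn.ReLU(inplace=True),
    )

block = conv_gn_block(3, 16)
x = torch.randn(4, 3, 32, 32)  # batch of 4, like my current setup
y = block(x)
```

With this block, the output is identical in train and eval mode, which would sidestep the model.train()/model.eval() discrepancy entirely.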
Here is a snippet of my code, inside the for loop over the epochs:
```python
net.train()
epoch_loss = 0
for batch in train_loader:
    imgs = batch['image']
    true_masks = batch['mask']
    imgs = imgs.to(device=device, dtype=torch.float32)
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    true_masks = true_masks.to(device=device, dtype=mask_type)

    masks_pred = net(imgs)
    loss = criterion(masks_pred, true_masks.squeeze(1))
    epoch_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_value_(net.parameters(), 0.1)
    optimizer.step()

# Validation
val_ce, val_iou = eval_net(net, model_arch, val_loader, device)
```
The eval_net function is defined as:
```python
net.eval()
mask_type = torch.float32 if net.n_classes == 1 else torch.long
n_val = len(loader)
tot_ce = 0
for batch in loader:
    imgs, true_masks = batch['image'], batch['mask']
    imgs = imgs.to(device=device, dtype=torch.float32)
    true_masks = true_masks.to(device=device, dtype=mask_type)
    with torch.no_grad():
        mask_pred = net(imgs)
    tot_ce += F.cross_entropy(mask_pred, true_masks.squeeze(1)).item()
```
Any idea what could be causing this problem?