Strange issue: UNet model cannot overfit

Hi all! I’m new to this forum. For quite a few days I’ve been trying, without success, to get a vanilla UNet model to overfit on a single training image. I’ve never run into this kind of issue before, so I thought I’d shout into the depths of the internet in case others have some insight.

The UNet is being trained for 2D image segmentation with ground truth masks (1 class). I am currently training on a single image to debug the network.

I have confirmed that my gradients and weights don’t explode (there is gradient clipping), the image normalization is correct (I scale pixels from 0-255 to 0-1 and then normalize with the dataset mean and stddev), the mask values are correct and only in {0, 1}, the shapes of the output and ground truth masks match, and the loss is BCE with logits. I’ve also swept learning rates from 1e-9 to 0.1 in case I had missed some level of granularity, and I have tried Kaiming, Xavier, and normal layer initializations (as well as the default PyTorch initialization).
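For reference, the checks listed above can be written down as a small sanity-check script. This is a minimal sketch, not the poster's code: the helper names `normalize_image` and `check_batch`, the placeholder mean/std values, the single `Conv2d` standing in for the UNet, and the clipping threshold `max_norm=1.0` are all hypothetical.

```python
import torch
import torch.nn as nn


def normalize_image(img_uint8: torch.Tensor, mean: float, std: float) -> torch.Tensor:
    """Scale 0-255 pixels to 0-1, then standardize with dataset mean/std."""
    x = img_uint8.float() / 255.0
    return (x - mean) / std


def check_batch(logits: torch.Tensor, mask: torch.Tensor) -> None:
    """Assert the mask is binary and has the same shape as the logits."""
    assert logits.shape == mask.shape, f"shape mismatch {logits.shape} vs {mask.shape}"
    values = torch.unique(mask)
    assert set(values.tolist()) <= {0.0, 1.0}, f"unexpected mask values {values}"


# Hypothetical example data: one 1-channel 64x64 image and a binary mask.
torch.manual_seed(0)
img = torch.randint(0, 256, (1, 1, 64, 64), dtype=torch.uint8)
mask = torch.randint(0, 2, (1, 1, 64, 64)).float()
x = normalize_image(img, mean=0.5, std=0.25)  # placeholder statistics

model = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in for the UNet
logits = model(x)
check_batch(logits, mask)

# BCEWithLogitsLoss expects raw logits (no sigmoid) and float targets.
loss = nn.BCEWithLogitsLoss()(logits, mask)
loss.backward()
# Clip the total gradient norm in place; the return value is the norm *before* clipping.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"loss={loss.item():.4f}, grad norm before clipping={total_norm.item():.4f}")
```

None of these checks would catch a train/eval mode problem, which is worth keeping in mind when they all pass.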

The model is the one from this repository:

The training procedure is also the one found here:

The model trains very erratically, with high variance from run to run. Within a single run (even within one epoch), the loss decreases for a couple of iterations, then jumps all over the place, and after a few epochs it increases steadily until it plateaus at around 0.7, even though the gradients don’t explode. The output predictions are also rather dismal and not meaningful.

I’ve also tried training a completely different model that was pretrained on a separate, related dataset, but using the same training and dataset code. Unfortunately, the behavior is exactly the same; I wonder if there’s an error in the training or dataset code, but I can’t seem to find one and I’ve tried debugging pretty much everything I could think of.

Would anyone be able to provide some guidance on this? Thanks so much!


Your debugging steps sound good.

I had a look at the model and couldn’t find any obvious errors, so I tried to overfit a random image using this simple code:

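import torch
import torch.nn as nn
# UNet: the model class from the repository mentioned in the question (import not shown)

model = UNet(1, 1)
x = torch.randn(1, 1, 224, 224)
target = torch.randint(0, 2, (1, 1, 224, 224)).float()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

for epoch in range(30):
    optimizer.zero_grad()
    output = model(x)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    # accuracy: threshold the logits at 0 (i.e. a sigmoid probability of 0.5)
    print('Epoch {}, loss {}, accuracy {}'.format(epoch, loss.item(),
        (target[0, 0] == (output[0, 0] > 0.0)).float().mean().item()))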

and can reach an accuracy of ~0.9984 after 30 epochs, so the model seems to work fine.

Also, I cannot find any issues in the training code.
Could you share a single input and target, if they are open source, so that we can have a look?

Thanks so much, @ptrblck! I spent today debugging once again by rewriting the entire training pipeline from scratch and testing on incrementally more meaningful sets of data, and I finally found the problem. It was very sneaky. from this starter code was actually fine, but in, net.eval() was called. However, net.train() was called only once at the beginning of each epoch, while net.eval() was called on every batch, so after the first batch the network stayed in eval mode for the rest of the epoch. Only one batch per epoch was actually trained in training mode - no wonder!
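The failure mode described above can be reproduced with a toy loop. This is a hypothetical sketch, not the starter code itself: `validate` stands in for whatever routine called `net.eval()` inside the batch loop, and the small Conv/BatchNorm/Dropout network stands in for the UNet. The fix shown is to call `net.train()` right before every training step.

```python
import torch
import torch.nn as nn


def make_net() -> nn.Module:
    # Stand-in for the UNet: BatchNorm and Dropout are the layers whose
    # behavior depends on train()/eval() mode.
    return nn.Sequential(
        nn.Conv2d(1, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
        nn.Dropout2d(0.2), nn.Conv2d(8, 1, 1),
    )


def validate(net: nn.Module, x: torch.Tensor) -> None:
    # Hypothetical evaluation helper that switches the net to eval mode
    # and never switches it back.
    net.eval()
    with torch.no_grad():
        net(x)


def run_epoch(net: nn.Module, batches, fixed: bool) -> list:
    """Train for one epoch; return the net.training flag seen at each step."""
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()
    modes = []
    net.train()  # called once per epoch, as in the buggy code
    for x, y in batches:
        if fixed:
            net.train()  # fix: restore train mode before every step
        modes.append(net.training)
        opt.zero_grad()
        loss = criterion(net(x), y)
        loss.backward()
        opt.step()
        validate(net, x)  # the bug: eval() is called on every batch
    return modes


torch.manual_seed(0)
batches = [(torch.randn(2, 1, 16, 16), torch.randint(0, 2, (2, 1, 16, 16)).float())
           for _ in range(4)]
print("buggy:", run_epoch(make_net(), batches, fixed=False))  # [True, False, False, False]
print("fixed:", run_epoch(make_net(), batches, fixed=True))   # [True, True, True, True]
```

An alternative design is to have the evaluation routine save `net.training` on entry and restore it on exit, so callers never have to remember to switch back.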

Thanks for your help and quick response on this problem! It’s a worthwhile lesson learned. I’ll probably throw in a pull request. :slight_smile:


I’m glad you’ve found this issue, and the authors would certainly be happy if you fixed it. :wink:
Note that even after calling model.eval(), the model will still be trained (Autograd still calculates gradients and the optimizer still updates the parameters), but the behavior of some layers changes: e.g. batch norm layers use their running stats instead of the batch stats, and dropout is disabled. This would certainly explain the poor training performance.
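The three effects mentioned above can be checked directly on small standalone layers. This is an illustrative sketch with arbitrary layer sizes and toy inputs, not code from the thread: it shows that an optimizer step still changes weights in eval mode, that `BatchNorm1d` normalizes with running stats in eval mode and batch stats in train mode, and that `Dropout` is the identity in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 1) eval() does not disable autograd: gradients flow and the optimizer still
#    updates the weights. (torch.no_grad() is what turns off gradient tracking.)
model = nn.Sequential(nn.Linear(3, 3), nn.Dropout(p=0.5))
model.eval()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
w_before = model[0].weight.detach().clone()
loss = model(torch.randn(4, 3)).pow(2).mean()
loss.backward()
opt.step()
weights_changed = not torch.equal(w_before, model[0].weight)
print("weights updated in eval mode:", weights_changed)

# 2) BatchNorm: eval() normalizes with the running stats (initially mean 0,
#    var 1); train() normalizes with the current batch's stats.
bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3) * 5 + 10  # batch statistics far from the initial running stats
with torch.no_grad():
    bn.eval()
    out_eval = bn(x)   # roughly x itself, mean around 10
    bn.train()
    out_train = bn(x)  # per-channel mean around 0
print("eval mean:", out_eval.mean().item(), "train mean:", out_train.mean().item())

# 3) Dropout: identity in eval mode; in train mode it zeroes elements with
#    probability p and scales the survivors by 1 / (1 - p).
drop = nn.Dropout(p=0.5)
ones = torch.ones(1000)
drop.eval()
eval_is_identity = torch.equal(drop(ones), ones)
drop.train()
dropped = drop(ones)
zero_frac = (dropped == 0).float().mean().item()
print("eval identity:", eval_is_identity, "train zeroed fraction:", zero_frac)
```

In a UNet with batch norm in every block, case 2 matters most: early in training the running stats are still near their initial values, so eval-mode normalization during training steps can be far off.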