NaN training and validation losses with a U-Net model

Hello all, I am using a simple U-Net model that I adapted from here. The model works well with the author's PyTorch dataset class.

However, when I use this model with a new PyTorch dataset class that I wrote myself, both the training loss and the validation loss come out as NaN. I am using nn.CrossEntropyLoss(). I believe my dataset class works fine, but I must be missing something.

I have made my notebook openly available; it can be accessed here. I am working with Sentinel-2 optical imagery and trying to segment floodwater in it with the U-Net model, so this is essentially binary segmentation (floodwater or not floodwater).

Any help would be greatly appreciated. Thanks! :slight_smile:

Hey, I went through your implementation. Quite impressive work on using optical and SAR images together. I would suggest checking the following:

  1. Check whether the loss is still NaN when you use plain cross-entropy loss on its own, that is, with the Dice loss removed.
  2. If the loss is still NaN with cross-entropy alone, cast the target masks to long when computing the cross-entropy loss and to float when computing the Dice loss.
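To make the two checks above concrete, here is a minimal sketch of the dtype handling. The tensor shapes, the two-class setup, and the `dice_loss` helper are assumptions made for illustration; they are not taken from the notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical shapes: batch of 2, 2 classes (not floodwater / floodwater), 8x8 tiles.
logits = torch.randn(2, 2, 8, 8)        # raw model output, shape (N, C, H, W)
mask = torch.randint(0, 2, (2, 8, 8))   # ground-truth class indices, shape (N, H, W)

# Cross entropy with class-index targets expects dtype torch.long.
ce_loss = nn.CrossEntropyLoss()(logits, mask.long())


def dice_loss(logits, target, eps=1e-6):
    """Soft Dice loss on the class-1 probability (hypothetical helper)."""
    prob = F.softmax(logits, dim=1)[:, 1]   # (N, H, W) foreground probability
    target = target.float()                 # Dice arithmetic is done in float
    intersection = (prob * target).sum(dim=(1, 2))
    denom = prob.sum(dim=(1, 2)) + target.sum(dim=(1, 2))
    dice = (2 * intersection + eps) / (denom + eps)  # eps guards against 0/0
    return 1 - dice.mean()


d_loss = dice_loss(logits, mask)

print(ce_loss.item(), d_loss.item())
# Fail early if either term is NaN or inf, so the faulty term is easy to spot.
assert torch.isfinite(ce_loss) and torch.isfinite(d_loss)
```

Checking each loss term separately like this shows which one goes non-finite first, which is the point of removing the Dice loss in step 1.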

Hope this helps.


Thank you very much for your reply! I tried your suggestions, but later found that I was normalizing the input images differently from the original author, which caused the model not to work as expected.
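For anyone hitting the same problem, a quick sanity check on the normalized inputs can catch this before training. The sketch below is only one common scheme (per-channel standardization with a small epsilon) and is not necessarily what the original author or the notebook uses; the `normalize_per_channel` name and the tile shape are made up for illustration.

```python
import torch


def normalize_per_channel(img, eps=1e-6):
    """Standardize a (C, H, W) float tensor channel by channel.

    Assumption: a generic scheme for illustration, not the original author's.
    """
    if not torch.isfinite(img).all():
        raise ValueError("input contains NaN or inf values")
    mean = img.mean(dim=(1, 2), keepdim=True)
    std = img.std(dim=(1, 2), keepdim=True)
    return (img - mean) / (std + eps)   # eps avoids division by zero


torch.manual_seed(0)
# Hypothetical 4-band tile with raw values in the thousands.
tile = torch.rand(4, 16, 16) * 10000.0
tile[3] = 0.0  # a constant band has std 0; without eps this would give 0/0 = NaN

out = normalize_per_channel(tile)
print(out.mean(dim=(1, 2)), out.std(dim=(1, 2)))
# The model should only ever see finite inputs.
assert torch.isfinite(out).all()
```

Running the same check on a batch from the dataset class, right before it goes into the model, would show whether the NaN originates in the data rather than in the loss.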