Loss problem in net finetuning

Thanks for the code. I am currently working on it, creating some dummy data and targets.
One thing I’ve noticed so far is your use of transformations.
Since you are working on a segmentation task, I assume you have segmentation maps as the target.
I cannot see how your Dataset is implemented, but if you are using random transformations like RandomResizedCrop and flipping, you have to make sure the same transformations (with the same random parameters) are also applied to your target.
Otherwise only your input will be transformed, and the model might have a hard time learning the relationship between the input and target.

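The easiest way would be to use the functional API of torchvision (torchvision.transforms.functional), which lets you sample the random parameters once and apply them to both the input and the target.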
Here is a small example I created a while ago.

Let me know if this helps!