U-Net collapse during training

Hey all,

I am currently training a U-Net to predict the noise in a given image (diffusion) and am running into a weird issue: at some point the loss spikes and from then on stays roughly at that level, as shown in this image:

I use AdamW and a simple MSE loss between the true noise and the output of the U-Net (a sketch of the training step is included after the questions below). For the experiment shown here I use a batch size of 1, but the issue also appears with larger batch sizes. Regarding this I have two specific questions:

1. What is causing the spike? (How can it be avoided?)
2. Why does the model stop learning after the spike?
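
For context, this is roughly what my training step looks like (a simplified sketch; the noise schedule values and the exact `unet(x_t, t)` call are placeholders for my actual setup):

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class DiffusionModule(pl.LightningModule):
    """Simplified sketch of the noise-prediction setup (schedule values are placeholders)."""

    def __init__(self, unet, num_steps=1000, lr=1e-4):
        super().__init__()
        self.unet = unet  # the U-Net that predicts the added noise
        self.lr = lr
        betas = torch.linspace(1e-4, 0.02, num_steps)
        alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
        self.register_buffer("sqrt_ac", alphas_cumprod.sqrt())
        self.register_buffer("sqrt_one_minus_ac", (1.0 - alphas_cumprod).sqrt())

    def training_step(self, batch, batch_idx):
        x0 = batch  # clean images, shape (B, C, H, W)
        noise = torch.randn_like(x0)
        t = torch.randint(0, len(self.sqrt_ac), (x0.shape[0],), device=x0.device)
        # forward diffusion: x_t = sqrt(acp_t) * x0 + sqrt(1 - acp_t) * noise
        x_t = (self.sqrt_ac[t].view(-1, 1, 1, 1) * x0
               + self.sqrt_one_minus_ac[t].view(-1, 1, 1, 1) * noise)
        loss = F.mse_loss(self.unet(x_t, t), noise)  # simple MSE on the noise
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)
```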

I tried to fix the issue with gradient clipping, which does prevent the collapse, but the resulting models turn out to be bad, so I would rather fix the problem at its “root”. I ran some experiments; some insights are shown below.
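
(For reference, I enabled clipping through Lightning's built-in Trainer options, roughly like below; the value 1.0 and the norm-based algorithm are just example choices.)

```python
from pytorch_lightning import Trainer

# Gradient clipping via Lightning's built-in Trainer options
# (clip value and algorithm here are only example choices)
trainer = Trainer(
    gpus=1,
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",
)
```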

The moving averages of Adam are also affected:
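
(In case it helps, this is roughly how the moments can be read out of the optimizer state for logging; summarizing them as a single global norm is just one possible choice.)

```python
import torch


def adam_moment_norms(optimizer: torch.optim.AdamW) -> dict:
    """Summarize Adam's exp_avg / exp_avg_sq buffers as global norms (one possible summary)."""
    m_sq, v_sq = 0.0, 0.0
    for group in optimizer.param_groups:
        for p in group["params"]:
            state = optimizer.state.get(p, {})
            if "exp_avg" in state:      # first moment (m_t)
                m_sq += state["exp_avg"].norm().item() ** 2
            if "exp_avg_sq" in state:   # second moment (v_t)
                v_sq += state["exp_avg_sq"].norm().item() ** 2
    return {"adam/exp_avg_norm": m_sq ** 0.5, "adam/exp_avg_sq_norm": v_sq ** 0.5}
```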

I ran multiple (9) experiments with the same random seed (PyTorch Lightning's seed_everything) and noticed some weird behaviour: up to a certain step (14) the runs are almost identical, but from then on the losses, gradients etc. start to differ:


Any idea what happens at step 14 or 15 that changes the model outcome?
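
For completeness, the runs are seeded roughly like this (the seed value is a placeholder). Whether non-deterministic CUDA kernels can explain the divergence after step ~14 is something I still want to check, e.g. via Lightning's deterministic flag:

```python
import pytorch_lightning as pl

pl.seed_everything(42, workers=True)  # same seed for all 9 runs (42 is a placeholder)

# Possible follow-up experiment: force deterministic algorithms so that runs
# cannot diverge due to non-deterministic CUDA kernels (at some speed cost).
trainer = pl.Trainer(
    gpus=1,
    deterministic=True,
)
```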

Left: model output, right: GT (before collapse):
[image]
Left: model output, right: GT (“during” collapse):
[image]
Left: model output, right: GT (after collapse):
[image]

I also noticed that this only happens for large models and not smaller ones.

Specs:
Ubuntu 20.04
RTX 3090
CUDA 11.7
Python 3.7.13
PyTorch 1.11.0
PyTorch Lightning 1.6.1 (used for training)

I am thankful for any suggestions regarding the cause of the issue or on how I can continue my analysis. Thanks!