Why is Normalization causing my network to have exploding gradients in training?

I’ve built a network (in PyTorch) that performs well for image restoration. It’s an autoencoder with a ResNet-50 encoder backbone, but I’m only using a batch size of 1: I’m experimenting with some frequency-domain processing that only lets me handle one image at a time.

I’ve found that the network performs reasonably well, but only if I remove all batch normalization from it. Of course, batch norm is useless at a batch size of 1, so I switched over to group norm, which is designed for exactly this case. Yet even with group norm, my gradients explode. Training can go very well for 20–100 epochs and then it’s game over. Sometimes it recovers and then explodes again.
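For context, this is roughly how I swapped the norms — a minimal sketch, not my actual code. The `bn_to_gn` helper and the 32-group default are my own choices here (the demo model is a toy stand-in; the same call works on torchvision’s `resnet50()`):

```python
import torch
import torch.nn as nn

def bn_to_gn(module, num_groups=32):
    """Recursively swap every BatchNorm2d for a GroupNorm over the same channels."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # GroupNorm requires num_channels % num_groups == 0; fall back to
            # a single group (LayerNorm over channels) when it doesn't divide.
            g = num_groups if child.num_features % num_groups == 0 else 1
            setattr(module, name, nn.GroupNorm(g, child.num_features))
        else:
            bn_to_gn(child, num_groups)
    return module

# Toy stand-in for the encoder backbone.
net = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)
bn_to_gn(net)
y = net(torch.randn(1, 3, 32, 32))  # batch size of 1 is now fine
```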

I should also mention that during training, every image fed in is corrupted with a randomly chosen noise level, so the network learns to handle arbitrary noise amounts. Other papers have done this, though, and they report that batch norm helps address the resulting covariate shift.
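By “wildly different amount of noise” I mean something like the following sketch — the Gaussian model and the `(0.0, 0.5)` sigma range are illustrative, not my exact numbers:

```python
import torch

def add_random_noise(img, sigma_range=(0.0, 0.5)):
    """Corrupt one image with Gaussian noise at a level drawn fresh per image."""
    sigma = torch.empty(1).uniform_(*sigma_range).item()
    return img + sigma * torch.randn_like(img), sigma

noisy, sigma = add_random_noise(torch.zeros(3, 32, 32))
```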

I’m scratching my head at this one and wondering if anyone has suggestions. I’ve dialed in my learning rate and clipped the gradient norm, but that treats the symptom rather than the cause. I can post some code, but I’m not sure where to start and was hoping someone could offer a theory. Any ideas why my network trains well without GN but badly with it? Thanks!
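For reference, the clipping I’m doing is the standard `clip_grad_norm_` pattern — shown here with a dummy linear model, and `max_norm=1.0` is just illustrative:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

x, target = torch.randn(4, 8), torch.randn(4, 1)
loss = nn.functional.mse_loss(model(x), target)

opt.zero_grad()
loss.backward()
# Clip the global gradient norm before the step; this caps the blow-ups
# but doesn't explain why GN triggers them in the first place.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
opt.step()
```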