Numeric instability from BatchNorm?

Hi pytorch,
I’m building a VAE whose encoder uses BatchNorm in many of its layers, but I think it’s introducing some numeric instability. On each backward pass I check for very high gradients, and one layer consistently shows extremely high gradients for its weight and bias right before my loss goes to NaN and training crashes out. I’ve tried disabling the learned affine parameters (weight and bias) for the offending layer, but that only moves the problem to the next one, which I guess makes sense. Does anyone have any experience with something like this?
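For reference, this isn’t my exact model, but a minimal sketch of the kind of per-parameter gradient check I’m running after each backward pass (the stack, threshold, and loss here are just placeholders):

```python
import torch
import torch.nn as nn

THRESHOLD = 1e3  # illustrative cutoff for "very high" gradient norms

# Stand-in for the encoder: Linear + BatchNorm blocks
model = nn.Sequential(
    nn.Linear(32, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 16),
)

x = torch.randn(8, 32)            # dummy batch
loss = model(x).pow(2).mean()     # dummy loss
loss.backward()

# After backward(), scan every parameter's gradient norm
for name, p in model.named_parameters():
    if p.grad is not None:
        gnorm = p.grad.norm().item()
        if gnorm > THRESHOLD:
            print(f"high gradient norm in {name}: {gnorm:.2e}")
```

With my real model, the parameters flagged by a check like this are always the BatchNorm weight and bias of one layer.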

EDIT: I revised my checks and found that the weights and biases of each of the first 3 BatchNorm layers (the ones operating directly on my input data) end up with gradients with very high norms. This causes instability even though I’m using gradient clipping.
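In case the clipping itself is the issue: here’s roughly the setup I’m using (model, optimizer, and `max_norm` are placeholders). `clip_grad_norm_` runs after `backward()` and before `optimizer.step()`, and returns the pre-clip total norm, which is how I’m spotting the spikes:

```python
import torch
import torch.nn as nn

# Stand-in model and optimizer, just to show where clipping sits
model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.Linear(32, 4))
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(8, 16)
loss = model(x).pow(2).mean()

opt.zero_grad()
loss.backward()
# Clip the *total* norm across all parameters; the returned value is
# the norm before clipping, useful for logging spikes. error_if_nonfinite
# makes NaN/inf gradients fail loudly instead of propagating.
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0, error_if_nonfinite=True
)
opt.step()
print(f"pre-clip grad norm: {total_norm.item():.3e}")
```

Even with this in place, the pre-clip norms on those first 3 BatchNorm layers are enormous, so I suspect the clipping is just masking whatever is actually blowing up.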