Mixed precision VQ-VAE makes NaN loss

I’ve been trying to apply automatic mixed precision to this VQ-VAE implementation,
following the PyTorch documentation:

        optimizer.zero_grad()

        with autocast():
            out, latent_loss = model(img)
            recon_loss = criterion(out, img)
            latent_loss = latent_loss.mean()
            loss = recon_loss + latent_loss_weight * latent_loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if scheduler is not None:  # not using a scheduler in this run
            scheduler.step()

Unfortunately, the MSE loss appears for a split second and then immediately goes to NaN.
I don’t know if it is even possible to use AMP on a VQ-VAE, so any help would be appreciated.

I assume the model converges in float32 and the loss doesn’t blow up?
If so, could you check the parameters after each step for invalid values (NaNs and Infs)?

Yes, the model converges in float32. I have tried checking for NaNs and Infs with this code:

        for param in model.parameters():
            if torch.isnan(param.data).any() or torch.isinf(param.data).any():
                print("invalid value in parameter of shape", tuple(param.shape))

but it doesn’t flag anything. The loss is around 0.38 after the first step and then goes to NaN, because the tensors returned by out, latent_loss = model(img) are filled entirely with NaNs.
The anomaly detection gives me “RuntimeError: Function ‘MseLossBackward’ returned nan values in its 0th output.”
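For reference, anomaly detection can be enabled globally; here is a toy repro (placeholder tensors, not the actual model) showing the kind of error it raises:

```python
import torch
import torch.nn.functional as F

# Enable anomaly detection so the backward pass reports which Function
# produced the non-finite values -- this is what raises the error above.
torch.autograd.set_detect_anomaly(True)

# Toy repro with a NaN already in the input (placeholder data, not the VQ-VAE).
x = torch.tensor([1.0, float("nan")], requires_grad=True)
loss = F.mse_loss(x, torch.zeros(2))

error_message = ""
try:
    loss.backward()
except RuntimeError as e:
    error_message = str(e)

print(error_message)  # names the backward function that emitted NaNs
```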

I have isolated the problem and it seems to be in the Decoder, since the NaNs appear after line 213,
dec_t = self.dec_t(quant_t), in vqvae.py.

Thanks for helping

Thanks for the debugging. Could you disable autocast for self.dec_t and check the output range? This would narrow down whether the decoder outputs are indeed overflowing or whether the NaNs are created in a different operation.
I’m also not sure how the quantization works in this model. Could you explain a bit how it quantizes the values in FP32?
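Something like this sketch, where Decoder and its single conv are hypothetical stand-ins for the actual dec_t module:

```python
import torch
import torch.nn as nn

# Sketch: force one submodule to run in float32 inside an autocast region.
# Decoder is a made-up stand-in for dec_t, not the code from vqvae.py.
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(8, 8, 3, padding=1)

    def forward(self, x):
        # Disable autocast locally and cast the (possibly float16) input up,
        # so every op in this block runs in float32.
        with torch.cuda.amp.autocast(enabled=False):
            return self.conv(x.float())
```

The output range can then be printed right after the call (e.g. out.abs().max()) to see whether float16’s limit of roughly 65504 is being exceeded.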

Putting with torch.cuda.amp.autocast(enabled=False): before dec_t = self.dec_t(quant_t) doesn’t change anything, but if I put it in the forward pass of the decoder I now get the error RuntimeError: Function 'CudnnConvolutionBackward' returned nan values in its 1th output. and the traceback points to the forward pass. Sorry if I’m doing this wrong; I’m not quite sure where I should disable autocast.

No, your debugging sounds alright.
If you disable autocast for the decoder, were you able to isolate which method introduces the NaNs? Is it still the decoder?
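If it helps, the search can be automated with forward hooks; a sketch with a toy net standing in for the model (the helper name is made up):

```python
import torch
import torch.nn as nn

# Sketch: forward hooks that record which module first produces non-finite
# values. `net` is a toy stand-in for the VQ-VAE.
def add_nan_hooks(model, flagged):
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
            flagged.append(module.__class__.__name__)
    for m in model.modules():
        m.register_forward_hook(hook)

flagged = []
net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
add_nan_hooks(net, flagged)
_ = net(torch.full((1, 4), float("inf")))  # poisoned input to trigger the hooks
print(flagged)  # the first entry is the first offending module
```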

I cannot find any NaNs in the forward pass when disabling autocast for the decoder, but the traceback of the error says it is still the decoder that produces NaNs, more specifically a convolution layer in the decoder:

        \lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl
            result = self.forward(*input, **kwargs)
        lib\site-packages\torch\nn\modules\conv.py", line 423, in forward
            return self._conv_forward(input, self.weight)
        lib\site-packages\torch\nn\modules\conv.py", line 420, in _conv_forward
            self.padding, self.dilation, self.groups)
        (function _print_stack)

I have split the Sequential block into one line of code per layer, and the traceback now points to the first layer of the decoder, which is a simple convolution: nn.Conv2d(in_channel, channel, 3, padding=1)
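The layer-by-layer check looked roughly like this (toy blocks standing in for the real decoder from vqvae.py):

```python
import torch
import torch.nn as nn

# Sketch of the layer-by-layer check: run the Sequential one layer at a time
# and record where non-finite values first show up.
blocks = nn.Sequential(nn.Conv2d(2, 2, 3, padding=1), nn.ReLU(),
                       nn.Conv2d(2, 2, 3, padding=1))
x = torch.randn(1, 2, 4, 4)
first_bad = None
for i, layer in enumerate(blocks):
    x = layer(x)
    if first_bad is None and not torch.isfinite(x).all():
        first_bad = i
print(first_bad)  # None for this clean toy input
```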

Could you check the inputs, the output, and the parameters of this layer, in particular their stats?
Since this layer is now running in float32, either the values are huge in magnitude, or another layer is creating the invalid values.
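A hypothetical helper along these lines would do; call it on the layer’s input, output, and weight:

```python
import torch

# Made-up helper: summarize a tensor so each of the layer's tensors can be
# inspected for magnitude and invalid values.
def tensor_stats(t):
    t = t.detach().float()
    return {
        "min": t.min().item(),
        "max": t.max().item(),
        "mean": t.mean().item(),
        "has_nan": bool(torch.isnan(t).any()),
        "has_inf": bool(torch.isinf(t).any()),
    }

print(tensor_stats(torch.tensor([1.0, -2.0, 3.0])))
```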

I have checked the inputs and found some NaNs, so I tracked down the responsible layer, which was a nn.ConvTranspose2d(embed_dim, embed_dim, 4, stride=2, padding=1). I made it run in float32 and now there are no more NaNs, but the loss explodes to values like 6527757.50000. So again I found the responsible layer, self.quantize_t = Quantize(embed_dim, n_embed), but when I try to make it run in float32 I get the error:

        \vqvae.py", line 46, in forward
            + self.embed.pow(2).sum(0, keepdim=True)
        RuntimeError: expected scalar type Half but found Float

during the forward pass of the layer.
At this point, I’m not sure this is even feasible.

The exploding loss sounds concerning, which is why I asked about it in the first post, as it would blow up mixed-precision training.
I don’t quite understand the last error. Does the Quantize layer depend on the usage of amp, or why is it suddenly failing in float32?

I don’t really understand it either. It only happens when I try to run mixed precision with the problematic layers, including the Quantize layer, in float32. If everything is in float32, there’s no problem. Even if there is a way to fix this, I’m not sure I would get any significant speedup, since most layers would have to run in float32.
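For what it’s worth, the Half-vs-Float error above is most likely the float16 activations from the surrounding autocast region meeting the float32 embed buffer: once autocast is disabled, it no longer inserts casts. A toy sketch of the mismatch and the usual cast-up workaround (placeholder tensors and shapes, not the actual Quantize code):

```python
import torch

# A float32 "codebook" and float16 activations, standing in for Quantize's
# self.embed and its input (hypothetical shapes).
embed = torch.randn(4, 8)
x_half = torch.randn(2, 4).half()

mismatch = False
try:
    _ = x_half @ embed  # Half meets Float: raises a dtype RuntimeError
except RuntimeError:
    mismatch = True

# Workaround: inside the autocast-disabled region, cast the input up first.
with torch.cuda.amp.autocast(enabled=False):
    dist = x_half.float() @ embed  # both operands are float32 now
print(mismatch, dist.dtype)
```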