Hello,

I’ve been trying to apply automatic mixed precision (AMP) to this VQ-VAE implementation by following the PyTorch documentation:

```
optimizer.zero_grad()

# run only the forward pass and loss computation under autocast
with autocast():
    out, latent_loss = model(img)
    recon_loss = criterion(out, img)
    latent_loss = latent_loss.mean()
    loss = recon_loss + latent_loss_weight * latent_loss

# backward() is recommended outside the autocast context
scaler.scale(loss).backward()

scaler.step(optimizer)
scaler.update()

# scheduler.step() should come after optimizer.step()
if scheduler is not None:  # not using a scheduler
    scheduler.step()
```

Unfortunately, the MSE looks normal for a split second and then immediately goes to NaN.
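One way I’ve been thinking about narrowing this down is to check each loss term for finiteness before calling `backward()`, to see whether the NaN first appears in `recon_loss` or `latent_loss`. A minimal, framework-agnostic sketch (`check_finite` is a hypothetical helper, not part of the original code; it would be called on `loss.item()`-style scalars):

```
import math

def check_finite(name, value):
    """Raise as soon as a scalar loss term becomes non-finite,
    so the offending term can be identified before NaNs
    propagate through backward()."""
    if not math.isfinite(value):
        raise RuntimeError(f"{name} is non-finite: {value}")
    return value

# Hypothetical usage inside the training loop:
# check_finite("recon_loss", recon_loss.item())
# check_finite("latent_loss", latent_loss.item())
```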

I don’t know whether it is even possible to use AMP with a VQ-VAE, so any help would be appreciated.

Thanks