GradScaler: TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead

Hello,

I am running PyTorch 2.5.0 on macOS and training a simple VAE on the MPS device with autocast enabled.

This is the error I get:

scaler.unscale_(optimizer)
  File "..../lib/python3.10/site-packages/torch/amp/grad_scaler.py", line 335, in unscale_
    inv_scale = self._scale.double().reciprocal().float()
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

This is the code:

import torch
import torch.nn.functional as F
from torch import optim
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
....
optimizer = optim.AdamW(vae.parameters(), lr=1e-3, fused=device == "cuda")
scaler = GradScaler(device)
with autocast(device):
    recon, mu, log_var = vae(data)
    recon_loss = F.mse_loss(recon, data, reduction="sum")
    # KL divergence
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    # Total loss
    loss = recon_loss + kl_loss

scaler.scale(loss).backward()
scaler.unscale_(optimizer) # <- breaks here
tot_norm = torch.nn.utils.clip_grad_norm_(vae.parameters(), 8.0).item()
scaler.step(optimizer)
scaler.update()

I assume GradScaler is not yet supported on MPS. Are there any workarounds?