GradScaler: TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead

Hello,

I am running PyTorch 2.5.0 on macOS, training a simple VAE on the MPS device with autocast and GradScaler.

This is the error I get:

scaler.unscale_(optimizer)
  File "..../lib/python3.10/site-packages/torch/amp/grad_scaler.py", line 335, in unscale_
    inv_scale = self._scale.double().reciprocal().float()
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

This is the code:

import torch
import torch.nn.functional as F
from torch import optim
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
....
optimizer = optim.AdamW(vae.parameters(), lr=1e-3, fused=device == "cuda")
scaler = GradScaler(device)
with autocast(device):
    recon, mu, log_var = vae(data)
    recon_loss = F.mse_loss(recon, data, reduction="sum")
    # KL divergence
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    # Total loss
    loss = recon_loss + kl_loss

scaler.scale(loss).backward()
scaler.unscale_(optimizer) # <- breaks here
tot_norm = torch.nn.utils.clip_grad_norm_(vae.parameters(), 8.0).item()
scaler.step(optimizer)
scaler.update()

I assume GradScaler is not yet supported on MPS. Are there any workarounds?

One workaround is to patch torch/amp/grad_scaler.py so that it skips the conversion to double on MPS. The line in unscale_ that computes inv_scale:

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        inv_scale = self._scale.double().reciprocal().float() # No good on mps

needs to change so that MPS tensors skip the conversion to double:

        # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
        assert self._scale is not None
        if self._scale.device.type == "mps":
            # MPS has no float64, so take the reciprocal directly in float32.
            inv_scale = self._scale.reciprocal()
        else:
            inv_scale = self._scale.double().reciprocal().float()
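
To get a feel for how much precision the float32 branch gives up, here is a minimal sketch that can be run on CPU. The helper name inverse_scale is hypothetical (it is not part of torch); it only mirrors the branch from the patch above. The assumption is that the scale is a power of two, which holds for GradScaler's defaults (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5).

```python
import torch


def inverse_scale(scale: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the patched branch in unscale_."""
    if scale.device.type == "mps":
        # MPS has no float64, so take the reciprocal directly in float32.
        return scale.reciprocal()
    # Everywhere else keep the original behaviour:
    # reciprocal in float64, then cast back to float32.
    return scale.double().reciprocal().float()


# GradScaler's default initial scale, stored as a float32 scalar tensor.
scale = torch.full((), 65536.0, dtype=torch.float32)
inv = inverse_scale(scale)
print(inv.dtype, inv.item())

# Compare the plain float32 reciprocal with the float64 round trip
# for a range of power-of-two scales.
all_equal = all(
    torch.equal(s.reciprocal(), s.double().reciprocal().float())
    for s in (torch.tensor(2.0 ** k, dtype=torch.float32) for k in range(25))
)
print("float32 and float64 reciprocals agree:", all_equal)
```

For power-of-two scales the reciprocal is exactly representable in float32, so the two paths should give the same result. With a custom growth_factor or backoff_factor that is not a power of two, the last bit may differ. If editing the installed file is not an option, constructing the scaler with enabled=False on MPS turns scale, unscale_, step and update into pass-throughs, at the cost of losing loss scaling, which can matter for float16 gradients.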