Gradients' dtype is not fp16 when using torch.cuda.amp

When I use torch.cuda.amp.autocast, the output is fp16, but the gradients are not fp16. Why?

import torch
a = torch.randn([4, 5], requires_grad=True, device='cuda')  # float32 leaf tensor
b = torch.randn([5, 4], requires_grad=True, device='cuda')  # float32 leaf tensor
with torch.autocast(device_type='cuda', dtype=torch.float16):
    c = a @ b  # matmul runs in float16 under autocast
c.backward(torch.ones_like(c))
print(c.dtype)       # torch.float16
print(a.grad.dtype)  # torch.float32

cc @ptrblck
Could you help?

This is expected, since the parameters stay in float32 and are not replaced with lower-precision dtypes.
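Here is a minimal sketch that makes the cast visible (the explicit .half() calls and retain_grad are only there to expose a copy of what autocast does internally; it assumes a CUDA device is available):

import torch
a = torch.randn([4, 5], requires_grad=True, device='cuda')  # float32 leaf tensor
b = torch.randn([5, 4], requires_grad=True, device='cuda')  # float32 leaf tensor
with torch.autocast(device_type='cuda', dtype=torch.float16):
    a_half = a.half()      # explicit stand-in for the internal cast to float16
    a_half.retain_grad()   # keep the gradient of this non-leaf fp16 copy
    c = a_half @ b.half()  # matmul runs in float16
c.backward(torch.ones_like(c))
print(c.dtype)            # torch.float16
print(a_half.grad.dtype)  # torch.float16 -> gradient of the fp16 copy
print(a.grad.dtype)       # torch.float32 -> gradient accumulated into the fp32 leaf

The gradient w.r.t. the fp16 copy is fp16, but the backward of the cast converts it back to float32 before it is accumulated into the float32 leaf's .grad.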

Can you explain the internal mechanism? My guess: the loss is fp16, so the gradients produced in the backward pass are also fp16. However, the parameters are fp32, so PyTorch converts the gradients from fp16 back to fp32 before accumulating them. Is that right?
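As a quick check of that guess (a small sketch, assuming a CUDA device is available), a backward hook on c shows the gradient flowing out of the fp16 matmul, which can be compared with the gradient accumulated on the fp32 leaf:

import torch
a = torch.randn([4, 5], requires_grad=True, device='cuda')
b = torch.randn([5, 4], requires_grad=True, device='cuda')
with torch.autocast(device_type='cuda', dtype=torch.float16):
    c = a @ b
c.register_hook(lambda grad: print('grad of c:', grad.dtype))  # grad of c: torch.float16
c.backward(torch.ones_like(c))
print(a.grad.dtype)  # torch.float32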