Conversion to .bfloat16() makes requires_grad False

>>> x = torch.randn(2,2, requires_grad=True)
>>> x.requires_grad
True
>>> y = x.half()
>>> y.requires_grad
True
>>> z = x.bfloat16()
>>> z.requires_grad
False

Could someone help with fixing this?

It does look like a bug; thanks for reporting it.
You can follow the progress in this issue: https://github.com/pytorch/pytorch/issues/28548
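In the meantime, here is a minimal sketch of a possible workaround. It assumes the problem is limited to the .bfloat16() convenience method and that casting via Tensor.to(torch.bfloat16) still goes through the regular differentiable path, so it keeps requires_grad on the result; if your build is also affected there, you will likely have to wait for the fix tracked in the issue above.

import torch

x = torch.randn(2, 2, requires_grad=True)

# Assumed workaround: cast through Tensor.to() instead of the .bfloat16()
# convenience method, on the assumption that .to() records the cast in
# autograd and therefore preserves requires_grad on the result.
z = x.to(torch.bfloat16)
print(z.requires_grad)  # expected: True if .to() is unaffected by the bug

# Note: calling requires_grad_() on the result of .bfloat16() is not an
# equivalent fix, because the converted tensor is already detached from x,
# so it would only become a new leaf and gradients would not flow back to x.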
