Hello,
I would like to add automatic mixed precision support to some custom autograd functions, but decorating their forward methods with custom_fwd(cast_inputs=torch.float16) disables gradient tracking on the inputs whenever a cast actually occurs; inputs that are already of the target dtype pass through with requires_grad intact, as demonstrated below.
import torch  # Version 2.2.0
from torch.cuda.amp import autocast, custom_fwd


class IdentityFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input):
        print(input.dtype, input.requires_grad)
        return input


input = torch.tensor(0, dtype=torch.float32, device='cuda', requires_grad=True)
with autocast():
    output = IdentityFunc.apply(input)  # Prints torch.float16 False

input = torch.tensor(0, dtype=torch.float16, device='cuda', requires_grad=True)
with autocast():
    output = IdentityFunc.apply(input)  # Prints torch.float16 True
This idiosyncrasy is not mentioned in the docs, so it may be a bug; if so, I can file an issue on GitHub. Otherwise, is there a way to ensure tensors retain their requires_grad values through the cast? In many cases, forward needs to know whether a particular input will require its gradient to be computed during the backward pass (for example, to decide which tensors to save), and this behaviour makes that impossible to infer from the cast inputs.
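For concreteness, here is a minimal sketch of the kind of pattern I have in mind; MulFunc and its saving logic are made up for illustration, not taken from any library. I want forward to save each operand only when the other operand's gradient will actually be needed, but after the cast both requires_grad flags always read False, so nothing is ever saved:

import torch
from torch.cuda.amp import autocast, custom_fwd, custom_bwd


class MulFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, a, b):
        # Intent: save b only if a needs its gradient (grad_a = grad_output * b),
        # and vice versa. After custom_fwd casts the inputs, a.requires_grad and
        # b.requires_grad are both False, so both branches save None.
        ctx.save_for_backward(b if a.requires_grad else None,
                              a if b.requires_grad else None)
        return a * b

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        b_saved, a_saved = ctx.saved_tensors
        grad_a = grad_output * b_saved if b_saved is not None else None
        grad_b = grad_output * a_saved if a_saved is not None else None
        return grad_a, grad_b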
Thank you.