Casting Inputs Using custom_fwd Disables Gradient Tracking

Hello,

I would like to implement automatic mixed precision for some custom autograd functions, but decorating them with custom_fwd(cast_inputs=torch.float16) disables gradient tracking on the inputs unless they are already of the target dtype (that is, when no cast actually occurs), as demonstrated below.

import torch # Version 2.2.0
from torch.cuda.amp import autocast, custom_fwd


class IdentityFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input):
        print(input.dtype, input.requires_grad)
        return input


input = torch.tensor(0, dtype=torch.float32, device='cuda', requires_grad=True)
with autocast():
    output = IdentityFunc.apply(input) # Prints torch.float16 False

input = torch.tensor(0, dtype=torch.float16, device='cuda', requires_grad=True)
with autocast():
    output = IdentityFunc.apply(input) # Prints torch.float16 True

This idiosyncrasy is not mentioned in the docs, so could it be a bug? If so, I can file an issue on GitHub. Otherwise, is there a way to ensure tensors retain their requires_grad values through the cast? In many cases it is necessary to know whether a particular input requires its gradient to be computed during the backward pass, and this behaviour makes that impossible.

Thank you.

The reason for the difference is the cast combined with the fact that autograd is disabled inside the forward and backward methods of a custom autograd.Function; this is unrelated to amp:

class IdentityFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        print(input.dtype, input.requires_grad)
        # autograd is disabled inside a custom Function's forward, so this intermediate is not tracked
        a = input * 2
        print(a.requires_grad)
        return input * 2
    
    @staticmethod
    def backward(ctx, grad):
        return grad * 2

input = torch.tensor(0, dtype=torch.float32, device='cuda', requires_grad=True)
output = IdentityFunc.apply(input) 
# torch.float32 True
# False

Since the first code applies an operation (the to(torch.float16) cast) to the input, the new input tensor will have its .requires_grad attribute set to False, in the same way it happened for a in my example.
I don't think there is an easy solution using the custom_fwd decorator while still keeping the .requires_grad attribute alive, so you might need to cast the inputs manually in your custom autograd.Function if it's called inside an autocast region.
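Something along these lines could serve as a starting point (just a minimal sketch; ScaleFunc and the attributes stashed on ctx are placeholder names, not an official API). The idea is to record requires_grad and the dtype before casting manually inside forward, then use them in backward:

import torch
from torch.cuda.amp import autocast, custom_fwd, custom_bwd


class ScaleFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd  # no cast_inputs, so the tensor arrives with requires_grad intact
    def forward(ctx, input):
        # Record what is needed before the manual cast discards it
        ctx.input_requires_grad = input.requires_grad
        ctx.input_dtype = input.dtype
        input = input.to(torch.float16)  # manual cast
        return input * 2

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        if not ctx.input_requires_grad:
            return None
        # Cast the gradient back to the dtype of the original input
        return (grad * 2).to(ctx.input_dtype)


input = torch.tensor(1.0, dtype=torch.float32, device='cuda', requires_grad=True)
with autocast():
    output = ScaleFunc.apply(input)
output.backward()
print(input.grad)  # tensor(2., device='cuda:0')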

Thank you for the explanation, that makes sense. I will either manually cast the inputs as you suggest or pass arguments to the forward pass of my custom function denoting which tensors need their gradients computed.
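For the latter, I am picturing something like this (again just a rough sketch; the extra input_requires_grad argument is a plain Python bool I pass in myself, not part of any API):

import torch
from torch.cuda.amp import autocast, custom_fwd, custom_bwd


class IdentityFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, input, input_requires_grad):
        # The flag is evaluated by the caller, before cast_inputs discards the information
        ctx.input_requires_grad = input_requires_grad
        return input.clone()

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        grad_input = grad if ctx.input_requires_grad else None
        return grad_input, None  # no gradient for the bool flag


input = torch.tensor(1.0, dtype=torch.float32, device='cuda', requires_grad=True)
with autocast():
    output = IdentityFunc.apply(input, input.requires_grad)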