How to use torch.autograd.Function with AMP?

I implemented a custom function using torch.autograd.Function, but it gives me an error when I enable AMP. Conceptually it looks something like this:

class CustomFunction(torch.autograd.Function):

	@staticmethod
	def forward(ctx, inp, dim):
		...
		ctx.save_for_backward(tensor_value)
		...
		return answer

	@staticmethod
	@torch.autograd.function.once_differentiable
	def backward(ctx, grad_answer):
		...
		tensor_value, = ctx.saved_tensors
		...
		foo_tensor = grad_answer.addcmul(some_tensor, tensor_value, value=2)
		...

This results in:

RuntimeError: !(has_different_input_dtypes && !config.promote_inputs_to_common_dtype_ && (has_undefined_outputs || config.enforce_safe_casting_to_output_ || config.cast_common_dtype_to_outputs_)) INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1659484806139/work/aten/src/ATen/TensorIterator.cpp":405, please report a bug to PyTorch.

The problem is that tensor_value is float32 while, due to AMP, grad_answer is float16, and addcmul does not natively promote its inputs to the widest dtype. What is the intended / best / most general way to circumvent this issue?

The forward pass is executed under AMP / torch.autocast, so calling addcmul there would be autocast to the widest input type and not cause an issue. But since addcmul is needed in the backward pass, and the backward pass is not executed under autocast, the mismatch surfaces there. Other operations like add always promote as required, which is why you don't see this issue for them.
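
To make the dtype split concrete, here is a minimal probe (hypothetical, not my actual function; it assumes a CUDA device) that just prints what its backward() sees after the forward ran under autocast:

import torch

class DtypeProbe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(weight)        # saved as float32; autocast does not touch saved tensors
        return inp.mm(weight)                # autocast runs this mm in float16

    @staticmethod
    def backward(ctx, grad_out):
        weight, = ctx.saved_tensors
        print(grad_out.dtype, weight.dtype)  # torch.float16 vs torch.float32
        # a non-promoting mixed-dtype op (addcmul, mm, ...) would fail right here,
        # so cast explicitly just to keep this probe runnable
        return grad_out.to(weight.dtype).mm(weight.t()), None

inp = torch.randn(8, 8, device="cuda", requires_grad=True)
weight = torch.randn(8, 8, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = DtypeProbe.apply(inp, weight)
out.sum().backward()                         # the backward pass runs outside autocast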

Just manually hard-coding a cast of grad_answer to float32 is not a good or general solution: it is brittle, relies on knowing which tensor happens to end up with which dtype under AMP, and if anything changes it might force conversions to float32 where both inputs are actually float16. Also, backward() could theoretically contain tens of calls to methods that don't auto-promote; fixing all of those with manual casts is undesirable and inflexible, and might force certain operations to run at higher precision than actually intended, or incur unnecessary copies.

How is this normally handled?
Can addcmul be changed to auto-promote (implementation change)?
Can config.promote_inputs_to_common_dtype_ be temporarily enabled?
What dtype should the return value of the backward pass actually have? Always same as grad_answer? Same as inp? No restriction?

There is a function to compute the promoted type based on the dtypes of the inputs: torch.promote_types — PyTorch 2.0 documentation. Does this solve your issue?

So if I understand correctly I would have to do something like:

dtype = torch.promote_types(some_tensor.dtype, tensor_value.dtype)
dtype = torch.promote_types(dtype, grad_answer.dtype)
foo_tensor = grad_answer.to(dtype).addcmul(some_tensor.to(dtype), tensor_value.to(dtype), value=2)

for each call to a non-auto-promoting function like addcmul in backward()?
I guess for simple isolated cases this could be a workaround.
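
For completeness, the promotion could at least be folded into a small helper so it doesn't have to be spelled out at every call site - purely a sketch, and promote_all is a made-up name, not something PyTorch provides:

import torch
from functools import reduce

def promote_all(*tensors):
    # cast all arguments to their common promoted dtype (hypothetical helper)
    common = reduce(torch.promote_types, (t.dtype for t in tensors))
    return tuple(t.to(dtype=common) for t in tensors)

# usage inside backward(), mirroring the addcmul call above:
# g, s, v = promote_all(grad_answer, some_tensor, tensor_value)
# foo_tensor = g.addcmul(s, v, value=2)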

How about wrapping the backward() call in a torch.autocast context?

The documentation for torch.autocast says (Automatic Mixed Precision package - torch.amp — PyTorch 2.1 documentation):

autocast should wrap only the forward pass(es) of your network, including the loss computation(s). Backward passes under autocast are not recommended. Backward ops run in the same type that autocast used for corresponding forward ops.

The first half of that seems to advise against my proposition - how come?
And then what does the second part mean?

The second half means that autograd records the dtypes the operations were performed in during the forward pass and does the casting automatically, so that gradients match the dtypes of the inputs.
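
A quick way to see that cast-back behaviour for built-in ops (again just a sketch, assuming a CUDA device):

import torch

w = torch.randn(4, 4, device="cuda", requires_grad=True)   # float32 leaf
x = torch.randn(4, 4, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    y = x.mm(w)                                             # autocast runs mm in float16

y.register_hook(lambda g: print("grad wrt y:", g.dtype))    # float16, matches y
y.sum().backward()
print("grad wrt w:", w.grad.dtype)                          # float32, cast back to match w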

Have you checked out CUDA Automatic Mixed Precision examples — PyTorch 2.0 documentation? When using autocast you should decorate your custom Function's forward and backward like this to prevent type mismatch errors:

from torch.cuda.amp import custom_fwd, custom_bwd

class MyMM(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a.mm(b)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        a, b = ctx.saved_tensors
        return grad.mm(b.t()), a.t().mm(grad)

I hadn’t seen that yet.
Digging into the code it seems to do pretty much what I naively suggested - it executes the backward method with torch.autocast if the forward part had it enabled. This is definitely the intended and clean solution then, thanks.
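
For reference, applying those decorators to my original conceptual sketch would look roughly like this (keeping the elisions from above; custom_fwd and custom_bwd come from torch.cuda.amp as in the linked example, and the decorator order shown is just one reasonable choice):

from torch.cuda.amp import custom_fwd, custom_bwd

class CustomFunction(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, inp, dim):
        ...
        ctx.save_for_backward(tensor_value)
        ...
        return answer

    @staticmethod
    @custom_bwd   # re-runs backward under the autocast state recorded during forward
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_answer):
        ...
        tensor_value, = ctx.saved_tensors
        ...
        foo_tensor = grad_answer.addcmul(some_tensor, tensor_value, value=2)
        ...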

I guess the only thing I'm still unsure about is why, in general, the backward pass shouldn't be run under torch.autocast, yet running arbitrary backward methods under torch.autocast, like custom_bwd does, is fine.

it executes the backward method with torch.autocast if the forward part had it enabled.

Ahh interesting, that would make sense.

I think the difference is that whatever you do inside a custom Function's forward is opaque to the autograd engine, so in this particular case it wouldn't be able to do the automatic casting I mentioned.