I implemented a custom function using `torch.autograd.Function`, but it gives me an error when I enable AMP. Conceptually it looks something like this:
```python
class CustomFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, dim):
        ...
        ctx.save_for_backward(tensor_value)
        ...
        return answer

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_answer):
        ...
        tensor_value, = ctx.saved_tensors
        ...
        foo_tensor = grad_answer.addcmul(some_tensor, tensor_value, value=2)
        ...
```
This results in:

```text
RuntimeError: !(has_different_input_dtypes && !config.promote_inputs_to_common_dtype_ && (has_undefined_outputs || config.enforce_safe_casting_to_output_ || config.cast_common_dtype_to_outputs_)) INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1659484806139/work/aten/src/ATen/TensorIterator.cpp":405, please report a bug to PyTorch.
```
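For context, the failing call pattern is conceptually something like the following (a minimal sketch; the shapes and the downstream op are placeholders, not my actual code):

```python
# Placeholder inputs; the forward saves tensor_value at float32.
inp = torch.randn(8, 16, device="cuda", requires_grad=True)   # float32
weight = torch.randn(16, 16, device="cuda")                   # float32

with torch.autocast(device_type="cuda", dtype=torch.float16):
    answer = CustomFunction.apply(inp, 1)   # forward runs under autocast
    loss = (answer @ weight).sum()          # matmul is autocast to float16

loss.backward()  # backward() runs outside autocast; grad_answer arrives as float16
```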
The problem is that `tensor_value` is float32, while due to AMP `grad_answer` is float16, and `addcmul` doesn't natively promote to the widest input dtype. What is the intended / best / most general way to circumvent this issue?
The forward pass is executed under AMP / `torch.autocast`, so calling `addcmul` there would promote to the widest input and not cause a problem. But since `addcmul` is needed in the backward pass, and the backward pass is not executed under AMP, the error occurs there. Other operations like `add` always promote as required, which is why you don't see this issue with them.
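The dtype mix can be illustrated outside the autograd machinery (a minimal sketch with placeholder values; the exact behavior may differ across PyTorch versions, but it shows which ops promote):

```python
import torch

grad_answer  = torch.randn(4, dtype=torch.float16)   # what AMP hands to backward()
some_tensor  = torch.randn(4, dtype=torch.float32)   # assumed float32 here
tensor_value = torch.randn(4, dtype=torch.float32)   # saved in forward at float32

# add promotes to the widest input dtype, so mixing is fine:
print((grad_answer + tensor_value).dtype)            # torch.float32

# addcmul does not promote on the build referenced above, so this mix is
# what fails; newer releases may promote instead.
foo_tensor = grad_answer.addcmul(some_tensor, tensor_value, value=2)
```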
Just manually hard-coding a cast of `grad_answer` to float32 is not a good or general solution: it is brittle, relies on knowing which tensor happens to end up as which dtype under AMP, and if anything changes it might force conversions to float32 where both inputs are actually float16. Also, the `backward()` function could theoretically contain tens of calls to methods that don't auto-promote; patching all of those with manual casts is undesirable, inflexible, and might force certain operations to run at higher precision than actually intended, or incur unnecessary copies, etc.
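For concreteness, the kind of hard-coded workaround I would like to avoid looks roughly like this (sketch only, mirroring the pseudocode above):

```python
    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_answer):
        ...
        tensor_value, = ctx.saved_tensors
        # Brittle: assumes grad_answer happens to be float16 and tensor_value
        # float32 under the current AMP behavior; may force float32 math even
        # when both are really float16, and adds an extra copy.
        grad_answer = grad_answer.to(tensor_value.dtype)
        ...
        foo_tensor = grad_answer.addcmul(some_tensor, tensor_value, value=2)
        ...
```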
How is this normally handled?

- Can `addcmul` be changed to auto-promote (an implementation change)?
- Can `config.promote_inputs_to_common_dtype_` be temporarily enabled?
- What dtype should the return value of the backward pass actually have? Always the same as `grad_answer`? The same as `inp`? No restriction?