I implemented a custom function using `torch.autograd.Function`, but it gives me an error when I enable AMP. Conceptually it looks something like this:
```python
class CustomFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, dim):
        ...
        ctx.save_for_backward(tensor_value)
        ...
        return answer

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_answer):
        ...
        tensor_value, = ctx.saved_tensors
        ...
        foo_tensor = grad_answer.addcmul(some_tensor, tensor_value, value=2)
        ...
```
This results in:

```
RuntimeError: !(has_different_input_dtypes && !config.promote_inputs_to_common_dtype_ && (has_undefined_outputs || config.enforce_safe_casting_to_output_ || config.cast_common_dtype_to_outputs_)) INTERNAL ASSERT FAILED at "/opt/conda/conda-bld/pytorch_1659484806139/work/aten/src/ATen/TensorIterator.cpp":405, please report a bug to PyTorch.
```
The problem is that some of the operands end up as `float32` while others, due to AMP, end up as `float16`, and `addcmul` doesn't natively promote to the widest input. What is the intended / best / most general way to work around this issue?
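A stripped-down sketch of what I believe triggers the assert (the dtype assignments below are illustrative assumptions; the essential point is only that `addcmul` sees mixed `float16`/`float32` inputs):

```python
import torch

g = torch.randn(4, dtype=torch.float16)  # stands in for grad_answer
a = torch.randn(4, dtype=torch.float16)  # stands in for some_tensor
b = torch.randn(4, dtype=torch.float32)  # stands in for tensor_value

# Mixed input dtypes hit the TensorIterator INTERNAL ASSERT on the build
# from the traceback (1.12.x); other versions may raise a cleaner error.
g.addcmul(a, b, value=2)
```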
The forward pass is executed under AMP / `torch.autocast`, so calling `addcmul` there would promote to the widest input and not cause an issue. But since `addcmul` is needed in the backward pass, and the backward pass is not executed under autocast, this is a problem. Other operations like `add` always promote as required, which is why you don't see this issue for them; see the snippet below.
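For instance, plain `add` follows PyTorch's standard type-promotion rules even outside any autocast context (dtypes here are illustrative):

```python
import torch

x = torch.randn(4, dtype=torch.float16)
y = torch.randn(4, dtype=torch.float32)
print((x + y).dtype)  # torch.float32: add promotes mixed inputs to the widest dtype
```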
Just manually hard-coding a cast to `float32` is not a good, general solution: it is brittle, relies on implicit knowledge of which tensor happens to end up with which dtype under AMP, and if anything changes it might force conversions to `float32` where both inputs are actually `float16`. Also, the `backward()` function could theoretically contain dozens of calls to methods that don't auto-promote. Patching all of them with manual casts is undesirable and inflexible, and might force certain operations to run at higher precision than actually intended, or incur unnecessary copies, etc. The kind of fix I mean is sketched below.
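To illustrate, this is the blunt workaround I want to avoid (a sketch only, using the names from the `backward()` above; it lifts everything to `float32` even when all operands could have stayed `float16`):

```python
# Brittle workaround: blindly lift every operand to float32 so the dtypes agree.
# This forces float32 math (and possibly extra copies) even when all inputs
# are actually float16 and could have stayed that way.
foo_tensor = grad_answer.float().addcmul(
    some_tensor.float(), tensor_value.float(), value=2)
```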
How is this normally handled?

- Should `addcmul` be changed to auto-promote (an implementation change in PyTorch)?
- Should `config.promote_inputs_to_common_dtype_` be temporarily enabled?
- What dtype should the return value of the backward pass actually have? Always the same as `grad_answer`? The same as `inp`? No restriction?
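For reference, the most general manual fallback I can think of computes the promoted dtype explicitly instead of hard-coding one (`promoted_dtype` is a hypothetical helper, and the stand-in tensors are illustrative; this is exactly the per-call boilerplate I would like to avoid):

```python
import torch

def promoted_dtype(*tensors):
    # Hypothetical helper: the dtype standard type promotion would choose.
    dt = tensors[0].dtype
    for t in tensors[1:]:
        dt = torch.promote_types(dt, t.dtype)
    return dt

grad_answer = torch.randn(4, dtype=torch.float16)   # illustrative stand-ins
some_tensor = torch.randn(4, dtype=torch.float16)
tensor_value = torch.randn(4, dtype=torch.float32)

dt = promoted_dtype(grad_answer, some_tensor, tensor_value)  # float32 here
# .to(dt) is a no-op for tensors that already have that dtype.
foo_tensor = grad_answer.to(dt).addcmul(
    some_tensor.to(dt), tensor_value.to(dt), value=2)
```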