AMP working with custom op

I want to use AMP with https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/roi_align.py.

Following https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops, I applied custom_fwd and custom_bwd (with no arguments) to forward and backward, respectively. My modified code:

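```python
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.modules.utils import _pair


class RoIAlignFunction(Function):
    ...  # (rest of the class omitted)

    @staticmethod
    @custom_fwd  # applied with no arguments, as in the AMP notes
    def forward(ctx,
                input,
                rois,
                output_size,
                spatial_scale=1.0,
                sampling_ratio=0,
                pool_mode='avg',
                aligned=True):
        ctx.output_size = _pair(output_size)
        ctx.spatial_scale = spatial_scale
        ...  # (rest of forward omitted)

    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        rois, argmax_y, argmax_x = ctx.saved_tensors
        ...  # (rest of backward omitted)
```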

The code still reports an error:

```
  File "/usr/local/python16/lib/python3.7/site-packages/mmcv/ops/roi_align.py", line 71, in forward
    aligned=ctx.aligned)
RuntimeError: expected scalar type Half but found Float
```
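The AMP notes page linked above also describes a second pattern, for functions that need a particular dtype: `@custom_fwd(cast_inputs=torch.float32)`. The sketch below assumes (a guess, not confirmed by the traceback alone) that the mismatch is a Half `input` coming out of autocast-ed layers next to Float `rois`. `ScaleByRoIs` and `run` are hypothetical stand-ins for `RoIAlignFunction`; the explicit dtype check only imitates a kernel that refuses mixed dtypes.

```python
import contextlib

import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd


class ScaleByRoIs(Function):
    """Hypothetical stand-in for RoIAlignFunction: scales `input` by the sum of
    `rois` and, like a strict CUDA kernel, refuses mixed dtypes."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input, rois):
        # Inside an enabled CUDA autocast region, custom_fwd has already cast
        # floating-point CUDA inputs to float32 and turned autocast off here.
        if input.dtype != rois.dtype:
            raise RuntimeError(
                f"expected matching dtypes, got {input.dtype} and {rois.dtype}")
        ctx.save_for_backward(rois)
        return input * rois.sum()

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # Runs under the autocast state forward used (off, given cast_inputs).
        (rois,) = ctx.saved_tensors
        return grad_output * rois.sum(), None


def run(device="cpu"):
    use_amp = device == "cuda"
    x = torch.ones(2, 3, device=device, requires_grad=True)
    rois = torch.full((4,), 0.5, device=device)  # Float, like boxes from a loader
    amp_ctx = torch.cuda.amp.autocast() if use_amp else contextlib.nullcontext()
    with amp_ctx:
        # Under AMP, mimic a Half activation coming out of an autocast-ed conv.
        feats = x.half() if use_amp else x
        out = ScaleByRoIs.apply(feats, rois)
    out.sum().backward()
    return out, x.grad
```

On a machine without CUDA, `run("cpu")` never enables autocast, so `cast_inputs` has no effect and everything stays float32. `run("cuda")` exercises the cast path; with a bare `@custom_fwd` instead, the Half `feats` and Float `rois` would hit the dtype check, analogous to the error above. The trade-off is that the op then always runs in float32 under AMP.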

What do custom_fwd and custom_bwd do, specifically?

This refers to issue https://github.com/pytorch/pytorch/issues/47906