I want to use AMP with the RoIAlign op in mmcv (https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/roi_align.py).
Following https://pytorch.org/docs/stable/notes/amp_examples.html#functions-with-multiple-inputs-or-autocastable-ops I applied custom_fwd and custom_bwd (with no arguments) to forward and backward respectively. The modified code:
class RoIAlignFunction(Function):
    ... ...
    @staticmethod
    @custom_fwd
    def forward(ctx,
                input,
                rois,
                output_size,
                spatial_scale=1.0,
                sampling_ratio=0,
                pool_mode='avg',
                aligned=True):
        ctx.output_size = _pair(output_size)
        ctx.spatial_scale = spatial_scale
        ... ...
    @staticmethod
    @once_differentiable
    @custom_bwd
    def backward(ctx, grad_output):
        rois, argmax_y, argmax_x = ctx.saved_tensors
        ... ...
The code still reports an error:
File "/usr/local/python16/lib/python3.7/site-packages/mmcv/ops/roi_align.py", line 71, in forward
aligned=ctx.aligned)
RuntimeError: expected scalar type Half but found Float
What do custom_fwd and custom_bwd do specifically?
This is related to issue https://github.com/pytorch/pytorch/issues/47906
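
For comparison, the same AMP notes page also documents a cast_inputs=torch.float32 form of custom_fwd for functions that need a particular dtype. Below is a minimal sketch of that variant as I understand it from the docs. Float32OnlyOp and run_demo are hypothetical names, not mmcv code, and the multiply-by-two body only stands in for a kernel that accepts float32 alone. Whether this is the right fix for roi_align is not confirmed here.

```python
import torch
from torch.autograd import Function
# Import path from the linked notes; newer PyTorch releases also expose
# these decorators under torch.amp and may warn about this path.
from torch.cuda.amp import custom_bwd, custom_fwd


class Float32OnlyOp(Function):
    """Hypothetical stand-in for an op whose kernel only handles float32."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input):
        # When called inside an autocast region, floating-point CUDA tensor
        # arguments are cast to float32 before reaching this body, and
        # autocast is disabled while the body runs.
        return input * 2.0

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # custom_bwd runs backward under the autocast state recorded by
        # custom_fwd (disabled here, because cast_inputs was given).
        return grad_output * 2.0


def run_demo():
    """Call the op the way an AMP training step would (hypothetical helper)."""
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # On GPU, a half tensor mimics the output of an upstream autocast layer.
    dtype = torch.float16 if use_cuda else torch.float32
    x = torch.ones(4, device=device, dtype=dtype, requires_grad=True)
    with torch.cuda.amp.autocast(enabled=use_cuda):
        y = Float32OnlyOp.apply(x)
    y.sum().backward()
    return y, x.grad


if __name__ == "__main__":
    y, grad = run_demo()
    print(y.dtype, grad)
```

With the no-argument custom_fwd used in the modified code, the inputs are not cast, so forward receives whatever dtypes the caller passes (for example Half features together with Float rois).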