I have a custom op that looks like this:
```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

def staticmethod_combined(custom_decorator):
    # combine @staticmethod with another decorator in a single line
    return lambda func: staticmethod(custom_decorator(func))

class MyCustomOp(torch.autograd.Function):
    @staticmethod_combined(custom_fwd(cast_inputs=torch.float16))
    def forward(ctx, t1, t2, t3):
        ## do forward
        return output

    @staticmethod_combined(custom_bwd)
    def backward(ctx, grad_out):
        ## do backward
        return grad_t1, grad_t2, grad_t3
```
I want `t1` and `t2` to be `torch.float16` tensors, but `t3` to stay `torch.float32`. Right now my solution is to manually add an explicit type conversion at the top of `forward`, like:

```python
t3 = t3.float()
```
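For context, here is a minimal repro of what I am doing now. The class name `DtypeProbe`, the forward math, and the placeholder gradients are all stand-ins just to make it runnable; it assumes a CUDA device and uses `torch.cuda.amp` (I believe newer PyTorch versions also expose these helpers under `torch.amp` with a `device_type` argument):

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class DtypeProbe(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx, t1, t2, t3):
        # Under autocast, custom_fwd has already cast all three
        # floating-point CUDA inputs to float16 by this point.
        print(t1.dtype, t2.dtype, t3.dtype)  # torch.float16 for all three
        t3 = t3.float()       # my manual workaround
        return t1 * t2 + t3   # stand-in computation; promotes to float32

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        # Placeholder gradients just so the repro runs end to end.
        return grad_out, grad_out, grad_out

t1 = torch.randn(4, device="cuda", requires_grad=True)
t2 = torch.randn(4, device="cuda", requires_grad=True)
t3 = torch.randn(4, device="cuda", requires_grad=True)
with torch.autocast("cuda"):
    out = DtypeProbe.apply(t1, t2, t3)
out.sum().backward()
```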
What I want to know:
- Is there a more elegant way to handle this type conversion?
- After manually converting `t3` to `torch.float32`, do I need to convert its gradient back to half in `backward`, e.g. `grad_t3 = grad_t3.half()`? (See the sketch after this list for what I mean.)
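Concretely, I mean something like this sketch, reusing the `staticmethod_combined` helper and imports from above, with the same stand-in math as the repro (so the gradients are easy to write down):

```python
class MyCustomOp(torch.autograd.Function):
    @staticmethod_combined(custom_fwd(cast_inputs=torch.float16))
    def forward(ctx, t1, t2, t3):
        t3 = t3.float()          # manual cast back to float32
        ctx.save_for_backward(t1, t2)
        return t1 * t2 + t3      # stand-in computation

    @staticmethod_combined(custom_bwd)
    def backward(ctx, grad_out):
        t1, t2 = ctx.saved_tensors
        grad_t1 = grad_out * t2  # promoted to float32
        grad_t2 = grad_out * t1
        grad_t3 = grad_out       # gradient for the float32 branch
        # Is a cast like grad_t3 = grad_t3.half() needed here, since
        # custom_fwd originally handed forward a float16 t3?
        return grad_t1, grad_t2, grad_t3
```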
Thanks for any help~