Hi!
I’d like to wrap a C++ operation in a custom op so that it supports both autograd and torch.compile.
However, my C++ function returns two tensors: 1) the output and 2) temporary values that are only needed for the backward pass.
When subclassing torch.autograd.Function, I can easily wrap the C++ function, store the temporary values in ctx, and return only the output tensor. How can I do the same with register_autograd?
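For comparison, the torch.autograd.Function pattern described above might look like the minimal sketch below. `run` and `run_back` are hypothetical pure-PyTorch stand-ins for `_C.run` / `_C.run_back` (sin, with its derivative cos cached as `tmp`), chosen only so the example runs without the extension.

```python
import torch

# Hypothetical stand-ins for the C++ extension:
# run() returns (out, tmp); run_back() consumes tmp to compute the gradient.
def run(x):
    return torch.sin(x), torch.cos(x)

def run_back(x, dout, tmp):
    return dout * tmp

class CustomFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out, tmp = run(x)
        # tmp is kept on ctx for backward; only `out` is returned to the caller.
        ctx.save_for_backward(x, tmp)
        return out

    @staticmethod
    def backward(ctx, dout):
        x, tmp = ctx.saved_tensors
        return run_back(x, dout, tmp)

def custom_fun(x):
    return CustomFun.apply(x)
```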
E.g.
```python
import torch
# _C is the compiled C++ extension module

@torch.library.custom_op("mylib::custom_fun", mutates_args=())
def custom_fun(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    out, tmp = _C.run(x)
    return out, tmp  # this returns 2 outputs instead of only 'out'

def setup_context(ctx, inputs, output):
    (x,) = inputs
    out, tmp = output  # I need tmp here for backward, but don't want it as an output
    ctx.save_for_backward(x, tmp)

@torch.library.custom_op("mylib::backward_impl", mutates_args=())
def backward_impl(x: torch.Tensor, dout: torch.Tensor, tmp: torch.Tensor) -> torch.Tensor:
    return _C.run_back(x, dout, tmp)

def backward_fun(ctx, dout, dtmp):
    # backward receives one gradient per output of custom_fun
    x, tmp = ctx.saved_tensors
    return backward_impl(x, dout, tmp)

custom_fun.register_autograd(backward_fun, setup_context=setup_context)

out, tmp = custom_fun(data)  # this returns 2 values!
```