How to store temp variables with register_autograd without returning them as output?

Hi!

I’d like to write a custom op for a C++ operation to enable autograd and torch.compile support.
However, my C++ function returns two tensors: 1) the output and 2) temporary values needed only for the backward pass.

When inheriting from torch.autograd.Function, you can easily wrap the C++ function, store the temp values in ctx, and return only the single output tensor. How can I do the same with register_autograd?
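
For reference, here is roughly what I mean with torch.autograd.Function (a minimal sketch; CustomFun is just a placeholder name):

import torch

class CustomFun(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out, tmp = _C.run(x)
        # tmp never leaves the function: stash it for the backward pass
        ctx.save_for_backward(x, tmp)
        return out  # single output

    @staticmethod
    def backward(ctx, dx):
        x, tmp = ctx.saved_tensors
        return _C.run_back(x, dx, tmp)

out = CustomFun.apply(data)  # only 'out' comes back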

Here is my attempt with register_autograd:

import torch

# custom_op names must be namespaced ("lib::op") and the signature needs
# type annotations; "mylib" is just a placeholder namespace here
@torch.library.custom_op("mylib::custom_fun", mutates_args=())
def custom_fun(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    out, tmp = _C.run(x)
    return out, tmp  # this returns 2 outputs instead of only 'out'

def setup_context(ctx, inputs, output):  # inputs is a tuple
    x, = inputs
    out, tmp = output  # I need tmp here for backward, but don't want it as an output
    ctx.save_for_backward(x, tmp)  # tensors should go through save_for_backward

@torch.library.custom_op("mylib::backward_impl", mutates_args=())
def backward_impl(x: torch.Tensor, dx: torch.Tensor, tmp: torch.Tensor) -> torch.Tensor:
    return _C.run_back(x, dx, tmp)

def backward_fun(ctx, dx, dtmp):  # two op outputs -> two incoming grads (dtmp is usually None)
    x, tmp = ctx.saved_tensors
    return backward_impl(x, dx, tmp)

custom_fun.register_autograd(backward_fun, setup_context=setup_context)

out, tmp = custom_fun(data)  # this returns 2 values!
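
The only workaround I see is hiding tmp behind a plain Python wrapper (custom_fun_public is just an illustrative name), but the op itself still has two outputs, which is why backward_fun above needs the extra dtmp slot:

def custom_fun_public(x):
    out, _tmp = custom_fun(x)  # tmp is dropped here, but it is still an op output
    return out

out = custom_fun_public(data)

Is there a way to keep tmp entirely internal to the op, the way ctx allows with autograd.Function?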