How can I access ctx.xx outside a custom autograd function?

I am implementing a custom autograd function.
In the forward of class xxxFunction(torch.autograd.Function),
I set ctx.matches, and I want to access ctx.matches outside this function.
I tried:

```python
@staticmethod
def get_matches(ctx):
    return ctx.matches
```

Then, in class xxx(torch.nn.Module), I call xxxFunction.get_matches(), but it fails with:
get_matches() takes exactly 1 argument (0 given).

Thank you in advance!

@albanD could you help me with this problem?

Hi,

You can access it in the backward method, which receives the same ctx as its first argument.
I don’t think you can extract it outside the Function in a reliable way, though.
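A minimal sketch of both routes, assuming a hypothetical `MatchFunction` whose forward stores a plain Python list on `ctx.matches`. Inside `backward` the attribute is always available. Outside the Function, the only handle is the output's `grad_fn`, which for a custom Function is the same object that was passed to `forward` as `ctx`. That is an implementation detail, and it vanishes entirely when no graph is recorded (under `torch.no_grad()`, or when no input requires grad), which is one reason it is not a reliable way to get the data out.

```python
import torch


class MatchFunction(torch.autograd.Function):
    """Hypothetical Function: doubles x and records the indices of positive entries."""

    @staticmethod
    def forward(ctx, x):
        # Arbitrary non-tensor data attached to ctx (a Python list here).
        ctx.matches = (x > 0).nonzero().flatten().tolist()
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        # backward receives the same ctx object, so ctx.matches is available here.
        print("matches seen in backward:", ctx.matches)
        return grad_out * 2


x = torch.tensor([-1.0, 2.0, 3.0], requires_grad=True)
y = MatchFunction.apply(x)

# Unreliable route: for a custom Function, y.grad_fn is the ctx used in forward.
print(y.grad_fn.matches)  # [1, 2]

y.sum().backward()  # prints: matches seen in backward: [1, 2]

# With no graph recorded there is no grad_fn, hence no handle on ctx at all.
with torch.no_grad():
    y2 = MatchFunction.apply(x)
print(y2.grad_fn)  # None
```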
Why do you need to do this? Can’t you just return it as another output of the forward method?

I am facing the same issue. If I return it as another output of the forward method, how should I adjust the backward method? @albanD
Many thanks!

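The backward method receives one gradient per output of forward and must return one value per input of forward. Use None for everything that is not differentiable: ignore the gradient it gets for the extra output, and return None for inputs that do not need a gradient.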
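A minimal sketch of returning the data as an extra output, assuming a hypothetical `MatchFunction` that also takes a non-tensor `threshold` argument and a hypothetical `Xxx` module standing in for `xxx` from the question. `matches` is returned as a second output and marked with `ctx.mark_non_differentiable`, so `backward` still accepts one gradient per output (the one for `matches` is ignored) and returns one value per forward input, with `None` for `threshold`.

```python
import torch
import torch.nn as nn


class MatchFunction(torch.autograd.Function):
    """Hypothetical Function: scales x by 2 and returns indices where x > threshold."""

    @staticmethod
    def forward(ctx, x, threshold):
        out = x * 2
        # Integer index tensor returned as a second output instead of hiding it in ctx.
        matches = (x > threshold).nonzero().flatten()
        ctx.mark_non_differentiable(matches)
        return out, matches

    @staticmethod
    def backward(ctx, grad_out, grad_matches):
        # One gradient argument per forward output; grad_matches is ignored because
        # matches is non-differentiable.
        # One return value per forward input: a gradient for x, None for threshold.
        return grad_out * 2, None


class Xxx(nn.Module):
    """Hypothetical module standing in for `xxx` from the question."""

    def forward(self, x, threshold=0.0):
        out, matches = MatchFunction.apply(x, threshold)
        self.last_matches = matches  # now usable outside the Function
        return out


model = Xxx()
x = torch.tensor([-1.0, 2.0, 3.0], requires_grad=True)
out = model(x)
out.sum().backward()
print(model.last_matches)  # tensor([1, 2])
print(x.grad)              # tensor([2., 2., 2.])
```

Because `matches` travels with the outputs, it is also available under `torch.no_grad()`, unlike the `grad_fn` route.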

Here is an example: