Issues with custom torch.autograd.Function and custom jvp method

I’m trying to define a custom torch.autograd.Function with user-defined forward, setup_context, backward, and jvp methods. I want control over the Jacobian-vector product (JVP) and the ability to call it explicitly. However, I ran into an issue where calling the jvp method raises a TypeError, apparently because it also expects a ctx argument, even though I’m calling it as a static method.

Below is a minimal example to illustrate the problem:

Example Code

import torch

class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(input):
        # Perform forward computation
        return input * 2  # Example operation
    
    @staticmethod
    def setup_context(ctx, inputs, outputs):
        input, = inputs
        ctx.save_for_backward(input)
        #ctx.save_for_forward(input)
        ctx.input = input

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors
        input, = ctx.saved_tensors
        # Compute gradient w.r.t input
        grad_input = grad_output * 2  # Example gradient computation
        return grad_input

    @staticmethod
    def jvp(ctx, v):
        # Compute the Jacobian-vector product
        print('Using custom jvp')
        input = ctx.input
        # Deliberately wrong factor (3 instead of 2) so I can tell
        # whether torch.autograd.functional.jvp actually calls this
        jvp = v * 3
        return jvp
    
# Input tensor
input = torch.tensor([1.0, 2.0], requires_grad=True)

myfunc = MyAutogradFunction()

# Apply custom autograd function
output = myfunc.apply(input)

# Compute gradients
output.backward(torch.ones_like(output))

# To compute the JVP, one would presumably use:
v = torch.tensor([0.5, 1.0])  # Vector for JVP
jvp = torch.autograd.functional.jvp(MyAutogradFunction.apply, input, v)
print(output)
print(jvp)
# Computed this way, however, the custom jvp does not seem to be used

# Attempting to call jvp directly
jvp2 = myfunc.jvp(v)

print(jvp2)

The output I obtain is:

tensor([2., 4.], grad_fn=<MyAutogradFunctionBackward>)
(tensor([2., 4.]), tensor([1., 2.]))
Traceback (most recent call last):
  File "TEST_jvp.py", line 51, in <module>
    jvp2 = myfunc.jvp(v)
TypeError: jvp() missing 1 required positional argument: 'v'

The error occurs when I try to call the jvp method directly using myfunc.jvp(v). The message says that jvp() is missing a required positional argument: v, but I believe the argument actually missing is ctx: since jvp is a staticmethod, myfunc.jvp(v) binds v to the ctx parameter and leaves v unfilled. Why is this happening, and how can I call the jvp method properly to obtain the Jacobian-vector product with a test vector of my choosing?
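For what it’s worth, filling the ctx slot by hand with a made-up stand-in object does make the call go through, which seems to confirm the diagnosis (fake_ctx below is purely illustrative, not a real autograd context):

from types import SimpleNamespace

# Hand-built stand-in for ctx, only to satisfy the staticmethod's
# first positional argument; it carries the one attribute jvp reads
fake_ctx = SimpleNamespace(input=input)
jvp2 = MyAutogradFunction.jvp(fake_ctx, v)
print(jvp2)  # tensor([1.5000, 3.0000])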

Here is the list of installed packages with their versions:

Package             Version
------------------- ------------
aiohttp             3.9.5
aiosignal           1.3.1
asttokens           2.4.1
async-timeout       4.0.3
attrs               23.2.0
backcall            0.2.0
Brotli              1.1.0
cached-property     1.5.2
certifi             2024.7.4
cffi                1.16.0
cftime              1.6.4
charset-normalizer  3.3.2
colorama            0.4.6
comm                0.2.2
contourpy           1.1.1
cycler              0.12.1
debugpy             1.8.2
decorator           5.1.1
einops              0.8.0
executing           2.0.1
filelock            3.15.4
fonttools           4.53.1
frozenlist          1.4.1
fsspec              2024.6.1
gmpy2               2.1.5
h2                  4.1.0
h5py                3.11.0
hpack               4.0.0
hyperframe          6.0.1
idna                3.7
importlib_metadata  8.0.0
importlib_resources 6.4.0
ipykernel           6.29.5
ipython             8.12.3
jedi                0.19.1
Jinja2              3.1.4
joblib              1.4.2
jupyter_client      8.6.2
jupyter_core        5.7.2
kiwisolver          1.4.5
lightning           2.3.2
lightning-utilities 0.11.3.post0
loguru              0.7.2
markdown-it-py      3.0.0
MarkupSafe          2.1.5
matplotlib          3.7.3
matplotlib-inline   0.1.7
mdurl               0.1.2
meshio              5.3.5
mpi4py              3.1.3
mpmath              1.3.0
msgpack             1.0.8
multidict           6.0.5
munkres             1.1.4
nest-asyncio        1.6.0
netCDF4             1.7.1
networkx            3.1
numpy               1.23.4
packaging           24.1
parso               0.8.4
petsc4py            3.21.2
pexpect             4.9.0
pickleshare         0.7.5
pillow              10.4.0
pip                 24.0
platformdirs        4.2.2
ply                 3.11
pooch               1.8.2
prompt_toolkit      3.0.47
psutil              6.0.0
ptyprocess          0.7.0
pure-eval           0.2.2
pycparser           2.22
Pygments            2.18.0
pyparsing           3.1.2
PyQt5               5.15.9
PyQt5-sip           12.12.2
PySocks             1.7.1
python-dateutil     2.9.0
pytorch-lightning   2.3.3
PyYAML              6.0.1
pyzmq               26.0.3
requests            2.32.3
rich                13.7.1
scikit-learn        1.3.2
scipy               1.10.1
setuptools          70.2.0
sip                 6.7.12
six                 1.16.0
slepc4py            3.21.0
stack-data          0.6.3
sympy               1.12.1
threadpoolctl       3.5.0
toml                0.10.2
tomli               2.0.1
torch               2.1.2
torch_geometric     2.5.2
torch-scatter       2.1.2
torch-sparse        0.6.18
torch-spline-conv   1.2.2
torchmetrics        1.4.0.post0
tornado             6.4.1
tqdm                4.66.4
traitlets           5.14.3
trimesh             4.4.2
triton              2.1.0
typing_extensions   4.12.2
unicodedata2        15.1.0
urllib3             2.2.2
vtk                 9.3.1
wcwidth             0.2.13
wheel               0.43.0
wslink              2.1.1
yarl                1.9.4
zipp                3.19.2
zstandard           0.22.0

Thank you in advance.

When defining custom jvp/vjp methods you don’t call them directly, I believe. They are meant to be invoked by the autograd machinery, e.g. through torch.autograd’s forward-mode AD or the torch.func derivative transforms.
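For example, here is a minimal sketch using the low-level forward-mode AD API (from memory and untested against your exact setup, so treat it as a starting point). This path should route through your custom jvp staticmethod, unlike torch.autograd.functional.jvp, which by default computes the JVP via the double-backward trick and therefore only ever uses your backward:

import torch
import torch.autograd.forward_ad as fwAD

input = torch.tensor([1.0, 2.0], requires_grad=True)
v = torch.tensor([0.5, 1.0])  # tangent / test vector

# Dual tensors carry a primal value together with a tangent; applying
# the custom Function to one should invoke its jvp staticmethod
with fwAD.dual_level():
    dual_input = fwAD.make_dual(input, v)
    dual_output = MyAutogradFunction.apply(dual_input)
    primal, tangent = fwAD.unpack_dual(dual_output)

print(primal)   # expected: tensor([2., 4.])
print(tangent)  # expected: tensor([1.5000, 3.0000]), your deliberately wrong jvp

If I remember correctly, torch.func.jvp(MyAutogradFunction.apply, (input,), (v,)) takes the same forward-mode path, so it should trigger the custom jvp as well.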

Hi! Thank you for the answer!
Unfortunately I have tried torch.autograd.functional.jvp, but it does not seem to call the custom jvp I defined: neither the print('Using custom jvp') nor the deliberately wrong result (I used a factor of 3 instead of 2) shows up in the output I posted above.
Let me give some more context. I have an outer nn.Module which takes a torch.autograd.Function (like myfunc, but a little more complex) as input. In its forward method, this module should call the custom torch.autograd.Function.jvp after first running torch.autograd.Function.apply, so I need a way to call the jvp method directly; the pattern I am after is sketched below. I hoped I could call it as in the example above, but apparently it is not as simple as that (if it is possible at all). Any more ideas or comments?
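For concreteness, a rough sketch of the desired pattern (OuterModule and its signature are made up for illustration):

import torch
import torch.nn as nn

class OuterModule(nn.Module):
    # Placeholder sketch of what I want: run the custom Function in
    # forward, then evaluate its jvp explicitly for a test vector v
    def __init__(self, func):
        super().__init__()
        self.func = func  # a torch.autograd.Function subclass

    def forward(self, x, v):
        out = self.func.apply(x)
        # Desired (does not work as written): something like
        #   tangent = self.func.jvp(ctx, v)
        # but I have no handle on the ctx created by apply above
        return out

module = OuterModule(MyAutogradFunction)
out = module(torch.tensor([1.0, 2.0], requires_grad=True),
             torch.tensor([0.5, 1.0]))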