I’m trying to define a custom `torch.autograd.Function` with user-defined `forward`, `setup_context`, `backward`, and `jvp` methods. I want control over the Jacobian-vector product (JVP), and I want to be able to call it explicitly. However, the `jvp` method raises a `TypeError` because it expects a `ctx` argument, even though I’m calling it as a static method.
Below is a minimal example to illustrate the problem:
Example Code
```python
import torch


class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(input):
        # Perform forward computation
        return input * 2  # Example operation

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        input, = inputs
        ctx.save_for_backward(input)
        # ctx.save_for_forward(input)
        ctx.input = input

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors
        input, = ctx.saved_tensors
        # Compute gradient w.r.t. input
        grad_input = grad_output * 2  # Example gradient computation
        return grad_input

    @staticmethod
    def jvp(ctx, v):
        # Compute the Jacobian-vector product
        print('Using custom jvp')
        input = ctx.input
        jvp = v * 3  # Example JVP computation
        # The JVP is deliberately wrong (factor 3 instead of 2) so that
        # the difference from torch.autograd.functional.jvp is visible
        return jvp


# Input tensor
input = torch.tensor([1.0, 2.0], requires_grad=True)
myfunc = MyAutogradFunction()

# Apply custom autograd function
output = myfunc.apply(input)

# Compute gradients
output.backward(torch.ones_like(output))

# Suppose you want to compute the JVP; you would use:
v = torch.tensor([0.5, 1.0])  # Vector for JVP
jvp = torch.autograd.functional.jvp(MyAutogradFunction.apply, input, v)
print(output)
print(jvp)
# Computed this way, however, the custom jvp does not seem to be used

# Attempting to call jvp directly
jvp2 = myfunc.jvp(v)
print(jvp2)
```
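The observation in the last comment of the example is consistent with the PyTorch documentation for `torch.autograd.functional.jvp`, which notes that it computes the JVP with the "double backwards trick": it differentiates the result of `backward` a second time, so only `backward` runs and a user-defined `jvp` staticmethod is never consulted. The sketch below checks this by recording which staticmethods get invoked; `Doubler` and the `calls` list are hypothetical names used only for illustration.

```python
import torch

calls = []  # records which custom staticmethods were invoked


class Doubler(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing needs to be saved for this linear example

    @staticmethod
    def backward(ctx, grad_output):
        calls.append("backward")
        return grad_output * 2

    @staticmethod
    def jvp(ctx, x_tangent):
        calls.append("jvp")
        return x_tangent * 3  # deliberately wrong, as in the question


x = torch.tensor([1.0, 2.0])
v = torch.tensor([0.5, 1.0])
out, jvp_val = torch.autograd.functional.jvp(Doubler.apply, x, v)
print(calls)    # only "backward" is expected to appear
print(jvp_val)  # 2 * v from the double-backward trick, not the custom 3 * v
```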
The output I obtain is:

```
tensor([2., 4.], grad_fn=<MyAutogradFunctionBackward>)
(tensor([2., 4.]), tensor([1., 2.]))
Traceback (most recent call last):
  File "TEST_jvp.py", line 51, in <module>
    jvp2 = myfunc.jvp(v)
TypeError: jvp() missing 1 required positional argument: 'v'
```
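The `TypeError` follows from ordinary Python argument binding rather than from autograd. A `@staticmethod` receives no implicit first argument, whether it is called on the class or on an instance, so in `myfunc.jvp(v)` the tensor `v` is bound to the parameter `ctx` and nothing is left for `v`. The sketch below reproduces this without PyTorch. `StaticJvpDemo` is a hypothetical stand-in with the same `jvp(ctx, v)` signature as in the question, and the `SimpleNamespace` object is a hypothetical fake context carrying only an `input` attribute. Calling a real `jvp` with such a fake `ctx` would bypass autograd entirely and only works as long as `jvp` reads nothing else from `ctx` (for example, `ctx.saved_tensors` would not exist on it).

```python
from types import SimpleNamespace


class StaticJvpDemo:
    """Hypothetical stand-in with the same jvp signature as the question."""

    @staticmethod
    def jvp(ctx, v):
        # Same shape as the question's jvp: it reads an attribute from ctx.
        _ = ctx.input
        return v * 3


demo = StaticJvpDemo()

# A staticmethod gets no implicit first argument, so the single
# positional argument binds to `ctx` and `v` is left unfilled.
error_message = ""
try:
    demo.jvp(0.5)
except TypeError as exc:
    error_message = str(exc)

# Passing a stand-in ctx explicitly fills both parameters.
fake_ctx = SimpleNamespace(input=1.0)
result = StaticJvpDemo.jvp(fake_ctx, 0.5)

print(error_message)
print(result)
```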
The error occurs when I call the `jvp` method directly with `myfunc.jvp(v)`. The error message says that `jvp()` is missing the required positional argument `v`, but I believe the argument that is actually missing is `ctx`. Why is this happening? How can I call the `jvp` method directly to obtain the Jacobian-vector product for a random test vector?
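The PyTorch documentation for `torch.autograd.functional.jvp` points to forward-mode AD (`torch.func.jvp` or the low-level `torch.autograd.forward_ad` API) as the more efficient alternative, and forward-mode AD is where a custom `jvp` staticmethod is invoked: when a dual tensor passes through `apply`, autograd calls `jvp` with a proper `ctx`. The sketch below assumes the `torch.autograd.forward_ad` API (`dual_level`, `make_dual`, `unpack_dual`) and stores the input with `ctx.save_for_forward`, which the `torch.autograd.Function` docs describe as the way to make tensors available to `jvp` through `ctx.saved_tensors`. The `calls` list is only there to show that the custom method ran. `torch.func.jvp` should also route through a custom `jvp`, but it is not exercised here.

```python
import torch
import torch.autograd.forward_ad as fwAD

calls = []  # records that the custom jvp actually ran


class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(x_in):
        return x_in * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        x_in, = inputs
        ctx.save_for_backward(x_in)  # for backward()
        ctx.save_for_forward(x_in)   # for jvp(), read back via ctx.saved_tensors

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

    @staticmethod
    def jvp(ctx, v):
        calls.append("jvp")
        saved_input, = ctx.saved_tensors  # unused here; shows how jvp can read the input
        return v * 3  # still deliberately wrong, to make the custom path visible


x = torch.tensor([1.0, 2.0])
v = torch.randn(2)  # random test vector (tangent)

with fwAD.dual_level():
    dual_in = fwAD.make_dual(x, v)  # attach tangent v to primal x
    dual_out = MyAutogradFunction.apply(dual_in)
    primal, tangent = fwAD.unpack_dual(dual_out)

print(primal)   # forward result, 2 * x
print(tangent)  # output of the custom jvp, 3 * v
```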
Here is the list of installed packages and their versions (PyTorch is `torch 2.1.2`):
Package Version
------------------- ------------
aiohttp 3.9.5
aiosignal 1.3.1
asttokens 2.4.1
async-timeout 4.0.3
attrs 23.2.0
backcall 0.2.0
Brotli 1.1.0
cached-property 1.5.2
certifi 2024.7.4
cffi 1.16.0
cftime 1.6.4
charset-normalizer 3.3.2
colorama 0.4.6
comm 0.2.2
contourpy 1.1.1
cycler 0.12.1
debugpy 1.8.2
decorator 5.1.1
einops 0.8.0
executing 2.0.1
filelock 3.15.4
fonttools 4.53.1
frozenlist 1.4.1
fsspec 2024.6.1
gmpy2 2.1.5
h2 4.1.0
h5py 3.11.0
hpack 4.0.0
hyperframe 6.0.1
idna 3.7
importlib_metadata 8.0.0
importlib_resources 6.4.0
ipykernel 6.29.5
ipython 8.12.3
jedi 0.19.1
Jinja2 3.1.4
joblib 1.4.2
jupyter_client 8.6.2
jupyter_core 5.7.2
kiwisolver 1.4.5
lightning 2.3.2
lightning-utilities 0.11.3.post0
loguru 0.7.2
markdown-it-py 3.0.0
MarkupSafe 2.1.5
matplotlib 3.7.3
matplotlib-inline 0.1.7
mdurl 0.1.2
meshio 5.3.5
mpi4py 3.1.3
mpmath 1.3.0
msgpack 1.0.8
multidict 6.0.5
munkres 1.1.4
nest-asyncio 1.6.0
netCDF4 1.7.1
networkx 3.1
numpy 1.23.4
packaging 24.1
parso 0.8.4
petsc4py 3.21.2
pexpect 4.9.0
pickleshare 0.7.5
pillow 10.4.0
pip 24.0
platformdirs 4.2.2
ply 3.11
pooch 1.8.2
prompt_toolkit 3.0.47
psutil 6.0.0
ptyprocess 0.7.0
pure-eval 0.2.2
pycparser 2.22
Pygments 2.18.0
pyparsing 3.1.2
PyQt5 5.15.9
PyQt5-sip 12.12.2
PySocks 1.7.1
python-dateutil 2.9.0
pytorch-lightning 2.3.3
PyYAML 6.0.1
pyzmq 26.0.3
requests 2.32.3
rich 13.7.1
scikit-learn 1.3.2
scipy 1.10.1
setuptools 70.2.0
sip 6.7.12
six 1.16.0
slepc4py 3.21.0
stack-data 0.6.3
sympy 1.12.1
threadpoolctl 3.5.0
toml 0.10.2
tomli 2.0.1
torch 2.1.2
torch_geometric 2.5.2
torch-scatter 2.1.2
torch-sparse 0.6.18
torch-spline-conv 1.2.2
torchmetrics 1.4.0.post0
tornado 6.4.1
tqdm 4.66.4
traitlets 5.14.3
trimesh 4.4.2
triton 2.1.0
typing_extensions 4.12.2
unicodedata2 15.1.0
urllib3 2.2.2
vtk 9.3.1
wcwidth 0.2.13
wheel 0.43.0
wslink 2.1.1
yarl 1.9.4
zipp 3.19.2
zstandard 0.22.0
Thank you in advance.