I’m trying to define a custom `torch.autograd.Function` with user-defined `forward`, `setup_context`, `backward`, and `jvp` methods. I want control over the Jacobian-vector product (JVP) and to be able to call it explicitly. However, calling the `jvp` method directly raises a `TypeError` because it expects a `ctx` argument, even though I’m calling it as a static method.
Below is a minimal example to illustrate the problem:
Example Code
```python
import torch


class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(input):
        # Perform forward computation
        return input * 2  # Example operation

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        input, = inputs
        # Save the input so backward can read it back from ctx.saved_tensors
        ctx.save_for_backward(input)
        ctx.input = input  # also kept as an attribute, read by jvp below

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors
        input, = ctx.saved_tensors
        # Compute gradient w.r.t. input
        grad_input = grad_output * 2  # Example gradient computation
        return grad_input

    @staticmethod
    def jvp(ctx, v):
        # Compute the Jacobian-vector product
        print('Using custom jvp')
        input = ctx.input
        jvp = v * 3  # Example JVP computation
        # I compute jvp incorrectly in order to observe
        # the difference with torch.autograd.functional.jvp
        return jvp


# Input tensor
input = torch.tensor([1.0, 2.0], requires_grad=True)
# Note: autograd Function methods are all static; recent PyTorch versions
# warn when a Function is instantiated like this
myfunc = MyAutogradFunction()
# Apply custom autograd function
output = myfunc.apply(input)
print(output)

# Suppose you want to compute the JVP, you would use:
v = torch.tensor([0.5, 1.0])  # Vector for JVP
jvp = torch.autograd.functional.jvp(MyAutogradFunction.apply, input, v)
print(jvp)
# Computed this way, however, the custom jvp does not seem to be used

# Attempting to call jvp directly
jvp2 = myfunc.jvp(v)
```
The output I obtain is:
```
tensor([2., 4.], grad_fn=<MyAutogradFunctionBackward>)
(tensor([2., 4.]), tensor([1., 2.]))
Traceback (most recent call last):
  File "TEST_jvp.py", line 51, in <module>
    jvp2 = myfunc.jvp(v)
TypeError: jvp() missing 1 required positional argument: 'v'
```
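For reference, the second tensor in that output, `tensor([1., 2.])`, equals 2·v, the true JVP of `input * 2`, not the 3·v returned by the custom `jvp`. The PyTorch documentation for `torch.autograd.functional.jvp` notes that it computes the JVP with the "double backwards trick" (the backward of the backward), so it goes through `backward` rather than `jvp`. A custom `jvp` is what forward-mode AD calls, for example through the low-level `torch.autograd.forward_ad` API. A minimal sketch, assuming PyTorch 2.x (needed for the separate `setup_context`) and reusing the same toy operation; `x` and `v` are illustrative values:

```python
import torch
import torch.autograd.forward_ad as fwAD


class MyAutogradFunction(torch.autograd.Function):
    @staticmethod
    def forward(input):
        return input * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        (input,) = inputs
        ctx.save_for_backward(input)

    @staticmethod
    def backward(ctx, grad_output):
        # d(2x)/dx = 2; used by reverse mode and by functional.jvp's double backward
        return grad_output * 2

    @staticmethod
    def jvp(ctx, v):
        # Deliberately "wrong" (3 instead of 2) so it is visible when it is used
        print("Using custom jvp")
        return v * 3


x = torch.tensor([1.0, 2.0])
v = torch.tensor([0.5, 1.0])

# Double-backward JVP: goes through backward, so the custom jvp is not called
_, jvp_double_backward = torch.autograd.functional.jvp(MyAutogradFunction.apply, x, v)

# Forward-mode AD: x carries tangent v, and apply dispatches to the custom jvp
with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, v)
    dual_out = MyAutogradFunction.apply(dual_x)
    primal, tangent = fwAD.unpack_dual(dual_out)
    tangent = tangent.clone()  # plain copy that outlives the dual level

print(jvp_double_backward)  # 2 * v
print(primal)               # 2 * x
print(tangent)              # 3 * v, produced by the custom jvp
```

Only the forward-mode path should print `Using custom jvp`, which makes it easy to see which mechanism each call uses.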
The error occurs when I call the `jvp` method directly with `myfunc.jvp(v)`. The error message says that `jvp()` is missing the required positional argument `v`, although I believe the argument that is actually missing is `ctx`. Why is this happening? How can I call the `jvp` method directly to obtain the Jacobian-vector product for a random test vector?
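In plain Python terms, the traceback is what a `@staticmethod` produces: calling it through an instance passes no implicit first argument, so `myfunc.jvp(v)` binds `v` to `ctx`, and Python then reports the next unfilled parameter, `v`, as missing. A minimal reproduction without PyTorch; `MyFunctionLike` and the `SimpleNamespace` context are hypothetical stand-ins for illustration (a real `ctx` is created by PyTorch inside `apply`):

```python
import types


class MyFunctionLike:
    """Hypothetical class whose jvp has the same signature as the question's."""

    @staticmethod
    def jvp(ctx, v):
        _ = ctx.input  # like the question's jvp, read something from ctx
        return [3 * x for x in v]


obj = MyFunctionLike()
v = [0.5, 1.0]

# A staticmethod gets no implicit first argument, so v lands in the ctx slot
error_message = None
try:
    obj.jvp(v)
except TypeError as err:
    error_message = str(err)
print(error_message)

# Supplying a context object explicitly satisfies the (ctx, v) signature
fake_ctx = types.SimpleNamespace(input=[1.0, 2.0])
result = MyFunctionLike.jvp(fake_ctx, v)
print(result)  # [1.5, 3.0]
```

A hand-made context only works here because this `jvp` reads nothing but `ctx.input`; it is an assumption that such a stand-in would be enough for a `jvp` that relies on tensors saved by PyTorch.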
Thank you in advance.