Is it possible to get the value of f(x) while using torch.func.jacrev?

I would like to get the value of f(x) and its first derivative simultaneously. Is there any way to get that now?

Otherwise, the model has to run inference twice to get the value of f(x) and its first derivative.

Normally, the derivative is calculated after the loss, so the forward output is already available by then:

import torch
import torch.nn as nn

model = nn.Linear(3, 1)

inputs = torch.rand(3)
targets = torch.rand(1)

outputs = model(inputs)

criterion = nn.MSELoss()

loss = criterion(outputs, targets)
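To illustrate the point above: a single forward pass gives you `outputs`, and calling `loss.backward()` afterwards fills in the parameter gradients without re-running the model. A minimal sketch reusing the same toy setup (the `manual_seed` call is only there for reproducibility):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # only for reproducibility

model = nn.Linear(3, 1)
criterion = nn.MSELoss()

inputs = torch.rand(3)
targets = torch.rand(1)

outputs = model(inputs)             # one forward pass: this is f(x)
loss = criterion(outputs, targets)
loss.backward()                     # gradients w.r.t. the parameters

# Both the forward value and the gradients are now available
# without a second call to model(...)
print(outputs.detach())
print(model.weight.grad)            # shape (1, 3)
```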

Hi, thanks for the reply, but I would like to calculate the Jacobian of a function f(x) that is not the loss, using torch.func.jacrev.
As far as I can tell, torch.func.jacrev only returns the Jacobian, not the value of f(x).
Another solution I found is to call torch.autograd.grad for each row of f(x) instead, but it seems a bit slower than jacrev.
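The per-row torch.autograd.grad approach mentioned above could look roughly like the sketch below, assuming a hypothetical vector-valued `f` (here a toy `nn.Linear(3, 2)`). `f(x)` is evaluated once and each Jacobian row costs a separate backward pass, which is plausibly why it is slower than jacrev, since jacrev batches those backward passes with vmap.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

f = nn.Linear(3, 2)  # hypothetical vector-valued f: R^3 -> R^2

x = torch.rand(3, requires_grad=True)
y = f(x)  # single forward pass, y is the value f(x)

rows = []
for i in range(y.shape[0]):
    # One backward pass per output element gives one Jacobian row;
    # retain_graph=True keeps the graph alive for the next row.
    (row,) = torch.autograd.grad(y[i], x, retain_graph=True)
    rows.append(row)

jacobian = torch.stack(rows)  # shape (2, 3)
print(y.detach())
print(jacobian)
```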

Hi @Yanze_Wang,

This is possible within the torch.func namespace. Here’s a minimal reproducible example:

import torch
from torch.func import jacrev

x = torch.tensor(5.)  # dummy input

def f(x):  # our function
    return x**2

def func_with_aux(x):  # same *args as f(x)
    out = f(x)
    return out, out  # first output is differentiated, second is returned as aux

# has_aux=True makes jacrev return (jacobian, aux)
gradient, output = jacrev(func_with_aux, argnums=0, has_aux=True)(x)
print(gradient, output)  # prints tensor(10.) tensor(25.)
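The same trick should extend to a vector-valued model and, with torch.func.vmap, to a batch of inputs. A sketch, assuming a toy `nn.Linear(3, 2)` stands in for the real model: returning `out` twice lets jacrev differentiate the first copy and pass the second through as the auxiliary output, so the forward pass runs once per call.

```python
import torch
import torch.nn as nn
from torch.func import jacrev, vmap

torch.manual_seed(0)

model = nn.Linear(3, 2)  # stand-in for the real model

def func_with_aux(x):
    out = model(x)
    # first output is differentiated, second is passed through as aux
    return out, out

# Single input: jacobian has shape (2, 3), output has shape (2,)
jacobian, output = jacrev(func_with_aux, has_aux=True)(torch.rand(3))
print(jacobian.shape, output.shape)

# Batch of inputs: vmap maps the per-sample computation over dim 0
xs = torch.rand(4, 3)
batch_jac, batch_out = vmap(jacrev(func_with_aux, has_aux=True))(xs)
print(batch_jac.shape, batch_out.shape)  # torch.Size([4, 2, 3]) torch.Size([4, 2])
```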

Hi @AlphaBetaGamma96,
Thanks, that was really helpful.
I’ve run into a new issue: I also need to script my model with torch.jit.script, but I get the error below when scripting a model that uses jacrev.

torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/gridsan/ywang3/.conda/envs/rid_openmm/lib/python3.11/site-packages/torch/_functorch/", line 354
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
                                                                           ~~~~~ <--- HERE
           chunk_size: Optional[int] = None,

It’s probably related to the arguments of the jacrev function, but could you please explain this error a bit? Thanks!
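The message itself names the restriction: the TorchScript frontend refuses to compile a Python function whose signature has a keyword-only argument with a default value, and jacrev’s signature has such arguments (`has_aux=False`, `chunk_size=None`), so scripting a model that calls jacrev fails while compiling jacrev itself. A small standalone reproduction, using the hypothetical functions `scale_kwonly` and `scale_positional`:

```python
import torch
from torch.jit.frontend import NotSupportedError

def scale_kwonly(x: torch.Tensor, *, factor: float = 2.0) -> torch.Tensor:
    # keyword-only argument with a default, like jacrev's has_aux=False
    return x * factor

def scale_positional(x: torch.Tensor, factor: float = 2.0) -> torch.Tensor:
    # same function, but the default is on a regular positional argument
    return x * factor

try:
    torch.jit.script(scale_kwonly)
    kwonly_scripted = True
except NotSupportedError as err:
    kwonly_scripted = False
    print(err)  # same "keyword-only arguments with defaults" message

scripted = torch.jit.script(scale_positional)  # compiles fine
print(scripted(torch.ones(2)))
```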

Hi @Yanze_Wang,

I don’t believe torch.func supports torch.jit.script, so you’ll have to find a workaround for this. Why do you need to script your model?
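If scripting is a hard requirement, one possible workaround (an assumption on my part, not something torch.func provides) is to drop jacrev and compute the Jacobian inside the module with torch.autograd.grad, which TorchScript can compile. This is essentially the per-row approach mentioned earlier in the thread. A sketch with a hypothetical wrapper `ValueAndJacobian` around a toy `nn.Linear(3, 2)`; the `detach()` means no gradient flows back to the caller’s `x`, and training through the Jacobian would additionally need `create_graph=True`.

```python
from typing import List, Tuple

import torch
import torch.nn as nn

class ValueAndJacobian(nn.Module):
    """Hypothetical wrapper returning f(x) and df/dx from one forward pass."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3, 2)  # stand-in for the real model

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = x.detach().requires_grad_(True)
        y = self.net(x)  # single forward pass
        rows: List[torch.Tensor] = []
        for i in range(y.size(0)):
            # one backward pass per output element; keep the graph for the next row
            grads = torch.autograd.grad([y[i]], [x], retain_graph=True)
            g = grads[0]
            assert g is not None  # refines Optional[Tensor] to Tensor for TorchScript
            rows.append(g)
        return y, torch.stack(rows)

scripted = torch.jit.script(ValueAndJacobian())
value, jac = scripted(torch.rand(3))
print(value.shape, jac.shape)  # torch.Size([2]) torch.Size([2, 3])
```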