I would like to get the value of f(x) and its first derivative simultaneously. Is there currently any way to do that?
Otherwise, the model has to run inference twice to get both the value of f(x) and its first derivative.
Normally, the derivatives are calculated after the loss is computed:
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
inputs = torch.rand(3)
targets = torch.rand(1)

outputs = model(inputs)             # forward pass
criterion = nn.MSELoss()
loss = criterion(outputs, targets)
loss.backward()                     # gradients of the loss w.r.t. the parameters

print(model.weight.grad)
print(model.bias.grad)
Hi, thanks for the reply, but I would like to calculate the Jacobian of a function f(x) with torch.func.jacrev, where f(x) is not the loss.
As far as I can tell, torch.func.jacrev only returns the Jacobian, not the function value.
Another solution I found is to call torch.autograd.grad on each row of f(x) instead, but that seems a bit slower than jacrev.
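For reference, the row-by-row approach I mean looks roughly like this (a minimal sketch; the vector-valued f below is just a made-up example):

import torch

def f(x):  # hypothetical vector-valued function, purely for illustration
    return torch.stack([x.sum() ** 2, (x ** 2).sum()])

x = torch.rand(3, requires_grad=True)
out = f(x)

# One backward pass per output component: each call yields one row of the Jacobian
rows = [torch.autograd.grad(out[i], x, retain_graph=True)[0]
        for i in range(out.shape[0])]
jacobian = torch.stack(rows)  # shape: (num_outputs, num_inputs)
print(out, jacobian)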
Hi @Yanze_Wang,
This is possible within the torch.func namespace; here's a minimal reproducible example:
import torch
from torch.func import jacrev

x = torch.tensor(5.)  # dummy input

def f(x):  # our function
    return x ** 2

def func_with_aux(x):  # same *args as f(x)
    out = f(x)
    return out, out  # return 'out' twice: the second copy is the auxiliary output

gradient, output = jacrev(func_with_aux, argnums=0, has_aux=True)(x)
print(gradient, output)  # returns 10., 25.
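The trick is that with has_aux=True, jacrev differentiates only the first returned value and passes the second through unchanged, so returning out twice gives you the Jacobian and the function value from a single call. The same pattern works for vector-valued functions; a minimal sketch (the shapes here are my own example):

import torch
from torch.func import jacrev

def f(x):
    return x ** 2  # elementwise square, so the Jacobian is diagonal

def func_with_aux(x):
    out = f(x)
    return out, out  # first 'out' is differentiated, second is passed through as aux

x = torch.rand(3)
jacobian, output = jacrev(func_with_aux, has_aux=True)(x)
print(jacobian.shape, output.shape)  # torch.Size([3, 3]) torch.Size([3])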
Hi @AlphaBetaGamma96,
Thanks, that was really helpful.
I've run into a new issue: I also need to JIT-script my model, but I get the error below when scripting a model that involves jacrev.
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
File "/home/gridsan/ywang3/.conda/envs/rid_openmm/lib/python3.11/site-packages/torch/_functorch/eager_transforms.py", line 354
@exposed_in("torch.func")
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
~~~~~ <--- HERE
chunk_size: Optional[int] = None,
_preallocate_and_copy=False):
It's probably related to the arguments of the jacrev function, but could you please explain this error a bit? Thanks!
Hi @Yanze_Wang,
I don't believe torch.func supports torch.jit.script, so you'll have to find a work-around for this. The error itself is raised because TorchScript's compiler can't handle keyword-only arguments with defaults, which jacrev's signature uses (has_aux, chunk_size, etc.). Why do you need to jit your model?
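If you do need a compiled model, one possible work-around (my assumption, not something confirmed here, and support is version-dependent) is torch.compile, which in recent PyTorch releases has partial support for torch.func transforms:

import torch
from torch.func import jacrev

def f(x):
    return x ** 2

def func_with_aux(x):
    out = f(x)
    return out, out

# Compile the transformed function instead of scripting it; verify that your
# PyTorch version can compile torch.func transforms before relying on this.
compiled = torch.compile(jacrev(func_with_aux, has_aux=True))
x = torch.tensor(5.)
gradient, output = compiled(x)
print(gradient, output)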