I would like to get the value of f(x) and its first derivative simultaneously. Is there any way to get that now?
Otherwise, the model has to run inference twice to get the value of f(x) and its first derivative…
Normally, the derivative is calculated after the loss is.
```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
inputs = torch.rand(3)
targets = torch.rand(1)

outputs = model(inputs)
criterion = nn.MSELoss()
loss = criterion(outputs, targets)
loss.backward()

print(model.weight.grad)
print(model.bias.grad)
```
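The same pattern works when the derivative is taken with respect to the input rather than the parameters: one forward pass produces f(x), and `torch.autograd.grad` then reuses that graph, so nothing is evaluated twice. A minimal sketch, assuming a scalar-valued f; the function `f` below is a hypothetical stand-in for the model.

```python
import torch

def f(x: torch.Tensor) -> torch.Tensor:
    # hypothetical scalar-valued function standing in for the model
    return (x ** 2).sum()

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

value = f(x)                              # single forward pass
(grad,) = torch.autograd.grad(value, x)   # backward through the same graph

print(value.item())  # 14.0
print(grad)          # tensor([2., 4., 6.])
```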
Hi, thanks for the reply, but I would like to calculate the Jacobian of a function f(x), which is not the loss, using torch.func.jacrev.
As far as I can tell, torch.func.jacrev only returns the gradients, not the value of f(x).
Another solution I found is to call torch.autograd.grad for each row of f(x) instead, but it seems a bit slower than
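For a vector-valued f(x), the per-row `torch.autograd.grad` approach can still reuse a single forward pass by passing `retain_graph=True`. Alternatively, the `is_grads_batched=True` flag of `torch.autograd.grad` pushes all rows through one vectorized call, which may be faster. A sketch, assuming 1-D input and output; the `nn.Linear(3, 2)` model is only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 2)                # illustrative model: f maps R^3 -> R^2
x = torch.rand(3, requires_grad=True)

y = model(x)                           # one forward pass gives f(x)

# Option 1: one backward per output row, keeping the graph alive in between
rows = []
for i in range(y.shape[0]):
    (row,) = torch.autograd.grad(y[i], x, retain_graph=True)
    rows.append(row)
jac_loop = torch.stack(rows)           # shape (2, 3)

# Option 2: all rows at once; each row of the identity selects one output
(jac_batched,) = torch.autograd.grad(
    y, x, grad_outputs=torch.eye(y.shape[0]), is_grads_batched=True
)

# For a linear layer the Jacobian is simply the weight matrix
print(torch.allclose(jac_loop, model.weight))     # True
print(torch.allclose(jac_batched, model.weight))  # True
```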
This is possible within the torch.func namespace by using the has_aux argument of jacrev. Here's a minimal reproducible example:
```python
import torch
from torch.func import jacrev

x = torch.tensor(5.)  # dummy input

def f(x):  # our function
    return x**2

def func_with_aux(x):  # same *args as f(x)
    out = f(x)
    return out, out  # return 'out' twice: once to differentiate, once as aux

gradient, output = jacrev(func_with_aux, argnums=0, has_aux=True)(x)
print(gradient, output)  # tensor(10.) tensor(25.)
```
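The same has_aux trick extends to vector-valued functions, where jacrev returns the full Jacobian alongside f(x). For scalar-valued functions, `torch.func.grad_and_value` returns the gradient and the value directly, without the auxiliary wrapper. A sketch, assuming PyTorch 2.0 or later; `f` below is a hypothetical example function.

```python
import torch
from torch.func import grad_and_value, jacrev

def f(x: torch.Tensor) -> torch.Tensor:
    # hypothetical vector-valued function: R^3 -> R^3
    return x ** 2

def f_with_aux(x: torch.Tensor):
    out = f(x)
    return out, out  # (output to differentiate, auxiliary output)

x = torch.tensor([1.0, 2.0, 3.0])

jac, value = jacrev(f_with_aux, has_aux=True)(x)
print(jac)    # diagonal matrix with 2, 4, 6 on the diagonal
print(value)  # tensor([1., 4., 9.])

# Scalar-valued case: gradient and value in one call
g, v = grad_and_value(lambda t: (t ** 2).sum())(x)
print(g, v)   # tensor([2., 4., 6.]) tensor(14.)
```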
Hi @AlphaBetaGamma96 ,
Thanks, that was really helpful.
I've run into a new issue: I also need to script my model with TorchScript, but I get the error below when scripting my model:
```
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/home/gridsan/ywang3/.conda/envs/rid_openmm/lib/python3.11/site-packages/torch/_functorch/eager_transforms.py", line 354
@exposed_in("torch.func")
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
                                                                   ~~~~~ <--- HERE
           chunk_size: Optional[int] = None,
           _preallocate_and_copy=False):
```
It's probably related to the arguments of the jacrev function, but could you please explain this error a bit? Thanks!
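I don't believe torch.func transforms such as jacrev are supported by torch.jit.script (the error says TorchScript cannot compile jacrev's signature, because it uses keyword-only arguments with defaults such as has_aux), so you'll have to find a workaround for this. Why do you need to jit your model?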
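One possible workaround is to keep torch.func out of the scripted code and compute the Jacobian with `torch.autograd.grad`, which TorchScript can compile. A minimal sketch, assuming 1-D input and output; the `ValueAndJacobian` wrapper is hypothetical, the `nn.Linear(3, 2)` model is only for illustration, and because it loops over output rows it may be slower than jacrev.

```python
import torch
import torch.nn as nn
from typing import List, Optional, Tuple

class ValueAndJacobian(nn.Module):
    """Hypothetical wrapper returning f(x) and df/dx from a single forward pass."""

    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x = x.detach().requires_grad_(True)
        y = self.net(x)  # one forward pass
        eye = torch.eye(y.shape[0], dtype=y.dtype, device=y.device)
        rows: List[torch.Tensor] = []
        for i in range(y.shape[0]):
            # TorchScript expects grad_outputs as List[Optional[Tensor]]
            grad_outputs: List[Optional[torch.Tensor]] = [eye[i]]
            g = torch.autograd.grad([y], [x], grad_outputs=grad_outputs,
                                    retain_graph=True)[0]
            assert g is not None  # refines Optional[Tensor] to Tensor
            rows.append(g)
        return y.detach(), torch.stack(rows)

torch.manual_seed(0)
net = nn.Linear(3, 2)
scripted = torch.jit.script(ValueAndJacobian(net))
value, jac = scripted(torch.rand(3))
print(value.shape, jac.shape)           # torch.Size([2]) torch.Size([2, 3])
print(torch.allclose(jac, net.weight))  # True
```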