If one defines a model f, and x is training data, then we have f(x).
How can one concisely define a model f(w) as a function of its parameters w? Will PyTorch be able to differentiate it wrt w, like JAX/STAX can?
Could you explain this use case a bit, or post some dummy code showing what should be achieved?
Basically, I just want to linearize a model with respect to its weights, to have f(w) = f(w0) + jac(w0) @ (w - w0) when x is fixed.
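Concretely, the goal looks something like this sketch (assuming, just for illustration, that f is already written as a function of a flat parameter vector w, which is exactly the part that is awkward to get out of a PyTorch module):

```python
import torch

def linearize(f, w0):
    # First-order Taylor expansion of f around w0 (x is baked into f):
    #   f_lin(w) = f(w0) + J(w0) @ (w - w0)
    y0 = f(w0)
    J = torch.autograd.functional.jacobian(f, w0)  # (out_dim, n_params)

    def f_lin(w):
        return y0 + J @ (w - w0)

    return f_lin
```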
JAX, for example, already has Jacobian and Hessian functions, and it's quite low-level, so it allows you to set weights directly.
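For comparison, a rough JAX sketch (the toy model here is invented for illustration); because the weights are an ordinary function argument, differentiating with respect to them is direct:

```python
import jax
import jax.numpy as jnp

def f(w, x):
    return jnp.tanh(x @ w)      # toy model: w plays the role of the weights

x = jnp.ones((3, 2))            # fixed input
w0 = jnp.zeros(2)

y0 = f(w0, x)
J = jax.jacobian(f)(w0, x)      # Jacobian w.r.t. the first argument, w

def f_lin(w):
    return y0 + J @ (w - w0)
```

JAX even has jax.linearize, which returns f(w0) together with the linearized map in a single call.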
Given a model (inheriting from nn.Module), calculate the Jacobian of the model's output vector with respect to its parameters.
The following looks like a hack, but it does the job:
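The code itself isn't quoted here, so the following is only a reconstruction of the kind of trick being referred to (the helper names are mine): delete each nn.Parameter from the module and later re-assign a plain differentiable tensor in its place, so that torch.autograd.functional.jacobian can differentiate the output with respect to the weights. It mutates the model in place, which is part of what makes it unsafe.

```python
import torch
import torch.nn as nn

def _owner(module, name):
    # Walk to the submodule that owns attribute `name` ("net.0.weight" -> net.0).
    parts = name.split(".")
    for part in parts[:-1]:
        module = getattr(module, part)
    return module, parts[-1]

def make_functional(model):
    # Record the parameters, then delete them from the module so that plain
    # tensors can be assigned in their place later on.
    names = [n for n, _ in model.named_parameters()]
    params = tuple(p.detach().requires_grad_() for p in model.parameters())
    for name in names:
        owner, attr = _owner(model, name)
        delattr(owner, attr)
    return names, params

def load_weights(model, names, params):
    # Re-assign plain tensors where the nn.Parameters used to live.
    for name, p in zip(names, params):
        owner, attr = _owner(model, name)
        setattr(owner, attr, p)

model = nn.Linear(2, 3)
x = torch.randn(5, 2)                 # fixed input
names, w0 = make_functional(model)

def f(*w):
    load_weights(model, names, w)
    return model(x)

# Tuple with one Jacobian block per parameter tensor.
J = torch.autograd.functional.jacobian(f, w0)
```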
Hi, I see that you are using the unsafe hack I previously proposed.
Just wanted to let you know that there seems to be a new, more elegant way to do this.
Please take a look at NN Module functional API · Issue #49171 · pytorch/pytorch · GitHub, which should provide a clean solution after the 1.11 release.
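For reference, the API that grew out of that issue is a stateless call that takes the parameters as an explicit dict. In recent releases it is exposed as torch.func.functional_call (earlier under torch.nn.utils.stateless); a sketch of the same Jacobian computation with it:

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # torch.nn.utils.stateless.functional_call in older releases

model = nn.Linear(2, 3)
x = torch.randn(5, 2)                   # fixed input

names = [n for n, _ in model.named_parameters()]
w0 = tuple(p.detach() for p in model.parameters())

def f(*w):
    # Run the model with `w` substituted for its registered parameters.
    return functional_call(model, dict(zip(names, w)), (x,))

y0 = f(*w0)
J = torch.autograd.functional.jacobian(f, w0)  # one Jacobian per parameter tensor
```

Unlike the hack above, this does not permanently mutate the module's state, and with y0 and J in hand the linearization f(w0) + jac(w0) @ (w - w0) follows directly.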