If one defines a model `f`, and `x` is training data, then we have `f(x)`.

How can one concisely define a model `f(w)` as a function of its parameters `w`? Will PyTorch be able to differentiate it wrt `w`, like JAX/STAX can?

Could you explain this use case a bit, or post some dummy code showing what should be achieved?

Basically, I just want to linearize a model with respect to its weights around `w0`, so that `f(w) ≈ f(w0) + jac(w0) @ (w - w0)` when `x` is fixed.
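To make the goal concrete, here is a minimal sketch of that linearization. It assumes a recent PyTorch (2.0 or later) where `torch.func.functional_call` and `torch.func.jvp` are available; the toy model, the input `x`, and the helper names `f` and `f_lin` are made up for illustration. It uses a Jacobian-vector product rather than building the full Jacobian, which gives the same result for `jac(w0) @ (w - w0)`.

```python
import torch
from torch import nn
from torch.func import functional_call, jvp

torch.manual_seed(0)

# Hypothetical toy model and fixed input x.
model = nn.Sequential(nn.Linear(3, 4), nn.Tanh(), nn.Linear(4, 2))
x = torch.randn(5, 3)

# w0: the point in parameter space around which we linearize.
w0 = {name: p.detach().clone() for name, p in model.named_parameters()}


def f(w):
    """Model output as a function of the parameter dict w, with x held fixed."""
    return functional_call(model, w, (x,))


def f_lin(w):
    """First-order Taylor expansion: f(w0) + jac(w0) @ (w - w0)."""
    dw = {name: w[name] - w0[name] for name in w0}
    # jvp returns both f(w0) and the Jacobian-vector product jac(w0) @ dw.
    out0, jac_dw = jvp(f, (w0,), (dw,))
    return out0 + jac_dw


# A small perturbation of the weights, to compare f and f_lin.
w = {name: p + 1e-3 * torch.randn_like(p) for name, p in w0.items()}
```

For weights close to `w0`, `f_lin(w)` should agree with `f(w)` up to second-order terms in `w - w0`.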

JAX, for example, already has Jacobian and Hessian functions, and it's quite low-level, so it allows setting weights directly.

Given a model (a subclass of `nn.Module`), calculate the Jacobian of the model's output vector with respect to its parameters.
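One possible way to compute this, sketched under the assumption that PyTorch 2.0 or later is installed (so `torch.func.jacrev` and `torch.func.functional_call` exist). The single `nn.Linear` layer and the input are placeholders, not anything from the original post.

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev

torch.manual_seed(0)

# Hypothetical model and a single fixed input vector.
model = nn.Linear(3, 2)
x = torch.randn(3)

# Parameters as a plain dict of tensors, keyed by parameter name.
params = {name: p.detach() for name, p in model.named_parameters()}


def f(w):
    """Output vector of the model as a function of its parameters."""
    return functional_call(model, w, (x,))


# Dict with one Jacobian block per parameter:
#   jac["weight"] has shape (2, 2, 3), i.e. output dim x weight shape
#   jac["bias"]   has shape (2, 2),    i.e. output dim x bias shape
jac = jacrev(f)(params)

# Optionally flatten into a single (n_outputs, n_params) matrix.
n_out = 2
jac_matrix = torch.cat([jac[name].reshape(n_out, -1) for name in params], dim=1)
```

For a linear layer the result is easy to check by hand: the bias block is the identity, and row `i` of the weight block contains `x` in the slot for row `i` of the weight matrix and zeros elsewhere.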

This looks like a hack to do this,

Hi, I see that you are using the unsafe hack I previously proposed.

Just want to let you know that there seems to be a new, more elegant way to do this.

Please take a look at NN Module functional API · Issue #49171 · pytorch/pytorch · GitHub, which should provide an elegant solution after the 1.11 release.
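For readers arriving later: as far as I know, the functional API from that issue is exposed as `torch.func.functional_call` in PyTorch 2.0 and later (this is my assumption about where it ended up, so check your installed version). A rough sketch of calling a module with externally supplied weights, using made-up values:

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)

model = nn.Linear(3, 1)
x = torch.randn(4, 3)

# Replacement weights supplied from outside the module (hypothetical values).
w = {
    "weight": torch.ones(1, 3, requires_grad=True),
    "bias": torch.zeros(1, requires_grad=True),
}

# Runs the module's forward pass using w instead of model.weight / model.bias.
out = functional_call(model, w, (x,))

# Gradients flow to the tensors in w; the module's own parameters are untouched.
loss = out.sum()
grad_weight, grad_bias = torch.autograd.grad(loss, [w["weight"], w["bias"]])
```

Because the weights are ordinary tensors passed in as arguments, there is no need to overwrite the module's attributes by hand.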