Hi @rajoy99,
Expanding on @Grant_Norman’s answer, you should ideally only use torch.func functions when computing per-sample gradients and not mix torch.func calls with torch.autograd calls (at least not for the calculation of the per-sample gradient); that’s why you get a .grad not required error.
Within the torch.func namespace, when you vectorize over a function it will no longer see the batch dimension: it effectively redefines your function to work on a ‘single’ sample and then computes all these samples in parallel. So you also need to remove the indexing on your input Tensor (torch.func can’t see the batch dim!).
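As a minimal, self-contained sketch of that idea (the toy per_sample_fn and random samples here are just placeholders, not your model), the function is written for a single sample with no samples[i] indexing inside, and vmap maps it over the batch dimension for you,
import torch
from torch.func import vmap

def per_sample_fn(sample): #written for a single sample of shape (3,), no batch dim
    return (sample ** 2).sum() #note: no samples[i] indexing in here

samples = torch.randn(8, 3) #a batch of 8 samples
out = vmap(per_sample_fn)(samples) #shape (8,); vmap maps over dim 0 for you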
Also, as you have noticed, the arguments passed to torch.func calls may have their requires_grad flags disabled, as torch.func returns the gradients directly rather than populating the .grad attribute.
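As a minimal illustration of that (toy names, not from your code),
import torch
from torch.func import grad

def loss_fn(w, x):
    return ((w * x).sum()) ** 2

w = torch.randn(3) #no requires_grad needed
x = torch.randn(3)
dw = grad(loss_fn, argnums=0)(w, x) #gradient w.r.t. w returned directly; w.grad stays None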
To see all the changes between functorch and torch.func, have a look at the migrating from functorch to torch.func docs here.
Your code above should be rewritten to something like this,
net = Model(*args, **kwargs) #nn.Module
params = dict(net.named_parameters()) #params are now dicts in torch.func (rather than lists in functorch)

from torch.func import functional_call, vmap, jacrev, grad

def fcall(params, sample): #functionalized call (with params as an input)
    return functional_call(net, params, sample)

def compute_loss(params, sample):
    u_pred = fcall(params, sample)
    u_grads = grad(fcall, argnums=1)(params, sample) #assume the model outputs a scalar
    #u_t and u_x can be split out separately (depending on their shape and whether they're independent of each other)
    u_t, u_x = u_grads[0], u_grads[1] #assuming sample is ordered as (t, x)
    u_hess = jacrev(jacrev(fcall, argnums=1), argnums=1)(params, sample) #reverse-over-reverse Hessian calc.
    u_xx = u_hess[1, 1] #pick out d2u/dx2 (be careful of the shapes here and try to avoid broadcasting issues)
    residual = u_t + (u_pred * u_x) - (nu * u_xx) #nu is your viscosity constant, assumed defined elsewhere
    HSNR = 0.5 * (residual ** 2)
    return HSNR

losses = vmap(compute_loss, in_dims=(None, 0))(params, samples)
Now granted, the example above is for a model with only params and not buffers, but changing that should be minimal (I just don’t know how your model is defined, but something like torch.func.stack_module_state can give an example of how to combine buffers with params when vmap-ing). You’ll then need to modify the grad and jacrev calls to include the buffers object (but that won’t be too difficult).