Use vmap and grad to calculate gradients for one layer independently for each input in a batch

Hi @jdesmarais,

The reason you don’t have a .grad attribute is that you’re mixing torch.func with torch.autograd.grad, which must be done with care: the torch.func transforms are purely functional, so they return gradients as values rather than writing them to .grad.
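As a minimal illustration of that (the function f here is just a made-up toy loss), torch.func.grad hands the gradient back as its return value and never touches .grad:

import torch
from torch.func import grad

f = lambda w: (w ** 2).sum()  # toy scalar loss
w = torch.randn(3)
g = grad(f)(w)   # gradient comes back as the return value (equals 2 * w)
print(w.grad)    # None -- nothing was written to .grad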

If you want to vectorize your gradient calculation, you’ll need to compute it entirely within the torch.func namespace, with something like this:

from torch.func import grad, vmap

def calc_loss(params, x):  # add other args if need be
    return loss  # insert your loss computation here; it must return a scalar

# grad w.r.t. params (argnums=0); vmap shares params (None) and maps x over dim 0
gradients = vmap(grad(calc_loss, argnums=0), in_dims=(None, 0))(params, x)

Then you can concatenate them as you do above.
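For concreteness, here is a sketch of the full pattern for a single linear layer; the layer sizes, the random data, and the mean-squared-error loss are all just assumptions for illustration:

import torch
from torch.func import functional_call, grad, vmap

layer = torch.nn.Linear(4, 1)  # hypothetical one-layer model
params = {k: v.detach() for k, v in layer.named_parameters()}

def calc_loss(params, x, y):
    # functional_call runs the module with the supplied parameter dict
    pred = functional_call(layer, params, (x,))
    return ((pred - y) ** 2).mean()  # scalar loss for a single sample

x = torch.randn(8, 4)  # batch of 8 inputs
y = torch.randn(8, 1)

# params are shared (None); x and y are mapped over their batch dim 0
per_sample_grads = vmap(grad(calc_loss, argnums=0), in_dims=(None, 0, 0))(params, x, y)
# per_sample_grads['weight'] has shape (8, 1, 4); per_sample_grads['bias'] has shape (8, 1)

Note that vmap already stacks the per-sample gradients along a new leading batch dimension, so each entry of the returned dict is ready to use as-is.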