Use vmap and grad to calculate gradients for one layer independently for each input in a batch

Hi,

I am working on a project where I need to calculate gradients for a particular model layer independently for each input in the batch. I have it working well unvectorized, but when I try to vectorize to speed things up I get the following error: “RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn”

Has anyone seen this issue before?

Some relevant code:

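import numpy as np
from torch.autograd import grad
from torch.func import vmap

def point_gradients(loss, params=None):
    # gradient of a single scalar loss w.r.t. the chosen layer's parameters
    gradients = grad(loss, params, retain_graph=True)
    # flatten each parameter's gradient and join them into one NumPy vector
    gradients = np.concatenate([g.cpu().numpy().flatten() for g in gradients])
    return gradients

if vectorize:
    # this branch raises the RuntimeError described below
    grad_fn = vmap(point_gradients, in_dims=0)
    batch_gradients = grad_fn(losses, params=layer_params)
else:
    batch_gradients = [point_gradients(loss, params=layer_params) for loss in losses]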

If I set vectorize to False, so that point_gradients is called on each loss individually, everything works as expected.
But if I set vectorize to True, so that the whole tensor of losses is passed to the vmap-ed function, I get the error:
“RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn”

The shape of losses is torch.Size([256, 1]), if that is important.

Thanks for the help!

Hi @jdesmarais,

The reason your losses appear to have no grad_fn inside vmap is that you’re mixing torch.func (vmap) with torch.autograd.grad, which must be done with care.

If you want to vectorize your gradient calculation, you’ll need to compute it entirely within the torch.func namespace, with something like this:

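import torch
from torch.func import functional_call, grad, vmap

def calc_loss(params, x, y):  # add other args if need be
    out = functional_call(model, params, (x.unsqueeze(0),))  # one sample -> batch of 1
    return torch.nn.functional.mse_loss(out, y.unsqueeze(0))  # insert your loss func here

# assumes params = dict(model.named_parameters()); x, y batched along dim 0
gradients = vmap(grad(calc_loss, argnums=0), in_dims=(None, 0, 0))(params, x, y)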

Then you can flatten and concatenate them as you do above.
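Below is a minimal self-contained sketch of that pattern, assuming PyTorch 2.x (where torch.func is available). The small Sequential model, the MSE loss, and the batch size of 16 are illustrative stand-ins, not taken from the original post. It computes per-sample gradients with vmap(grad(...)), flattens them into one row per sample, and checks the result against an explicit loop using torch.autograd.grad.

```python
import torch
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Illustrative stand-in model; any nn.Module is handled the same way.
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
params = {k: v.detach() for k, v in model.named_parameters()}

x = torch.randn(16, 4)  # batch of 16 inputs
y = torch.randn(16, 1)  # matching targets

def calc_loss(params, x, y):
    # vmap strips the batch dim, so x has shape (4,); re-add a batch dim of 1 for the model.
    out = functional_call(model, params, (x.unsqueeze(0),))
    return torch.nn.functional.mse_loss(out, y.unsqueeze(0))

# grad differentiates w.r.t. argnums=0 (params); vmap maps over dim 0 of x and y only.
per_sample_grads = vmap(grad(calc_loss, argnums=0), in_dims=(None, 0, 0))(params, x, y)

# Flatten each parameter's gradient and concatenate: one row per sample.
flat = torch.cat([g.flatten(start_dim=1) for g in per_sample_grads.values()], dim=1)
print(flat.shape)  # torch.Size([16, 49]), since 4*8 + 8 + 8*1 + 1 = 49

# Sanity check against an explicit per-sample loop with torch.autograd.grad.
rows = []
for i in range(x.shape[0]):
    loss = torch.nn.functional.mse_loss(model(x[i:i + 1]), y[i:i + 1])
    gs = torch.autograd.grad(loss, list(model.parameters()))
    rows.append(torch.cat([g.flatten() for g in gs]))
loop = torch.stack(rows)
print(torch.allclose(flat, loop, atol=1e-5))  # True
```

The key design point is that the differentiation happens inside the function that vmap sees, so grad and vmap compose cleanly instead of calling torch.autograd.grad on tensors that vmap has already sliced.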

Is there a way to use torch.func.grad to calculate gradients for a specific model layer instead of the entire model? Thanks!
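One possible approach, sketched here under stated assumptions rather than as a confirmed answer: torch.func.grad only differentiates with respect to the argument(s) selected by argnums, so you can split the parameter dict into the layer you care about and everything else, pass both to the loss function, and differentiate only the first. The model, the choice of the last layer (the hypothetical prefix "2."), and the MSE loss are illustrative assumptions.

```python
import torch
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
all_params = {k: v.detach() for k, v in model.named_parameters()}

# Hypothetical choice: per-sample gradients only for the last Linear layer ("2." prefix).
layer_prefix = "2."
layer_params = {k: v for k, v in all_params.items() if k.startswith(layer_prefix)}
other_params = {k: v for k, v in all_params.items() if not k.startswith(layer_prefix)}

x = torch.randn(16, 4)
y = torch.randn(16, 1)

def calc_loss(layer_params, other_params, x, y):
    # Merge both dicts so the model sees a full set of parameters.
    out = functional_call(model, {**layer_params, **other_params}, (x.unsqueeze(0),))
    return torch.nn.functional.mse_loss(out, y.unsqueeze(0))

# argnums=0: differentiate only w.r.t. layer_params; other_params is just an input.
# in_dims: share both param dicts across the batch, map over dim 0 of x and y.
per_sample = vmap(grad(calc_loss, argnums=0), in_dims=(None, None, 0, 0))(
    layer_params, other_params, x, y
)

flat = torch.cat([g.flatten(start_dim=1) for g in per_sample.values()], dim=1)
print(flat.shape)  # torch.Size([16, 9]): 8 weights + 1 bias

# Compare with a per-sample loop that asks autograd for the last layer only.
rows = []
for i in range(x.shape[0]):
    loss = torch.nn.functional.mse_loss(model(x[i:i + 1]), y[i:i + 1])
    gs = torch.autograd.grad(loss, list(model[2].parameters()))
    rows.append(torch.cat([g.flatten() for g in gs]))
loop = torch.stack(rows)
print(torch.allclose(flat, loop, atol=1e-5))  # True
```

Because only the selected layer's tensors are differentiated, the returned dict contains just that layer's keys, which plays the same role as passing params=layer_params to torch.autograd.grad in the original loop.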