Hi @jdesmarais,
The reason you don’t have a `.grad` attribute is that you’re mixing `torch.func` with `torch.autograd.grad`, which must be done with care. If you want to vectorize your gradient calculation, you’ll need to compute it entirely within the `torch.func` namespace, something like this:
```python
from torch.func import grad, vmap

def calc_loss(params, x):  # add any other args if need be
    return loss  # insert your loss computation here

# one gradient per sample: params held fixed, vectorized over the batch dim of x
gradients = vmap(grad(calc_loss, argnums=0), in_dims=(None, 0))(params, x)
```
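For concreteness, here’s a minimal runnable sketch of the same pattern; the linear model, squared-error loss, and shapes are just placeholder assumptions, not your actual setup:

```python
import torch
from torch.func import grad, vmap

# hypothetical stand-in loss: linear model + squared error, purely for illustration
def calc_loss(params, x):
    pred = x @ params["w"] + params["b"]  # x is a single sample (features,) inside vmap
    return (pred - 1.0) ** 2              # scalar loss per sample

params = {"w": torch.randn(3), "b": torch.randn(())}
x = torch.randn(8, 3)  # batch of 8 samples

# grad w.r.t. argnums=0 (params); vmap maps over the leading dim of x only
per_sample_grads = vmap(grad(calc_loss, argnums=0), in_dims=(None, 0))(params, x)
print(per_sample_grads["w"].shape)  # torch.Size([8, 3])
print(per_sample_grads["b"].shape)  # torch.Size([8])
```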
Then you can concatenate them as you do above.
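In case it helps, one way to flatten those per-sample gradients into a single matrix, assuming the dict layout from the sketch above:

```python
# stack each param's per-sample grads into one (batch, total_params) matrix
flat = torch.cat([g.reshape(x.shape[0], -1) for g in per_sample_grads.values()], dim=1)
print(flat.shape)  # torch.Size([8, 4]) -- 3 weight grads + 1 bias grad per sample
```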