Gradient of a matrix with respect to its rows

Hey guys
I have a matrix M(X) with shape (N, F), where X is my dataset, also of shape (N, F). I want to calculate the gradient of M_ij with respect to X_i, the i-th row of X.
Currently, in PyTorch, I am doing it this way:

# grad[i, j] = row i of dM[i, j]/dX, i.e. the gradient of M[i, j] w.r.t. X[i]
grad = torch.stack([
    torch.stack([
        torch.autograd.grad(M[i, j], X, retain_graph=True)[0][i] for j in range(M.shape[1])
    ]) for i in range(M.shape[0])
])

This returns a tensor of shape (N, F, F).
However, this is painfully slow, since it runs N·F separate backward passes. Can you guys think of a way to vectorize this?
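Written out per entry, the tensor the loop builds keeps only the diagonal blocks (same row index i on both sides) of the full Jacobian ∂M/∂X, which itself has shape (N, F, N, F):

$$
G_{ijk} = \frac{\partial M_{ij}}{\partial X_{ik}}, \qquad i = 1, \dots, N, \quad j, k = 1, \dots, F
$$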

Hi @sherlock.h,

You’ll want to have a look at the torch.func namespace. You can compose torch.func.vmap with torch.func.grad (or torch.func.jacrev) to vectorize your gradient computation efficiently.
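Here is a minimal sketch of that composition. It assumes each row of M depends only on the matching row of X; the per-row function `f` and its weights `W`, `b` are hypothetical stand-ins for your real model. It uses `torch.func.jacrev` rather than `torch.func.grad`, because `grad` needs a scalar output while `jacrev` returns the whole (F, F) Jacobian of one row, and `vmap` maps that over the N rows. If rows of M can depend on other rows of X (for example through batch statistics), the sketch also shows a fallback that builds the full (N, F, N, F) Jacobian and keeps its diagonal blocks. That fallback needs O(N²F²) memory, so it is only practical for modest N. The loop from the question is included as a reference to check against.

```python
import torch
from torch.func import jacrev, vmap

torch.manual_seed(0)
N, F = 5, 3

# Hypothetical weights for the illustrative model
W = torch.randn(F, F)
b = torch.randn(F)

def f(x):
    # Hypothetical row-wise model: one row x of shape (F,) -> one row of M, shape (F,)
    return torch.tanh(W @ x + b)

def M_of_X(X):
    # Hypothetical non-row-wise model: the mean over rows couples every row of M to all of X
    return torch.tanh(X @ W.T + b) + X.mean(dim=0)

def loop_grad(M_fn, X):
    # Reference: the nested autograd loop from the question
    X = X.clone().requires_grad_(True)
    M = M_fn(X)
    return torch.stack([
        torch.stack([
            torch.autograd.grad(M[i, j], X, retain_graph=True)[0][i]
            for j in range(M.shape[1])
        ])
        for i in range(M.shape[0])
    ])

X = torch.randn(N, F)

# Row-wise case: jacrev(f) is the (F, F) Jacobian of one row; vmap maps it over the N rows.
# Result has shape (N, F, F) with grad_vmap[i, j, k] = dM[i, j] / dX[i, k].
grad_vmap = vmap(jacrev(f))(X)

# General case: full Jacobian of shape (N, F, N, F), then keep the blocks J[i, :, i, :].
full_J = jacrev(M_of_X)(X)
idx = torch.arange(N)
# Advanced indices separated by a slice: the broadcast index dim goes first -> shape (N, F, F)
grad_general = full_J[idx, :, idx, :]
```

Both results should agree with the loop up to floating-point rounding; the row-wise `vmap(jacrev(f))` version is the one that scales, because it never materializes the off-diagonal blocks.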

The docs for torch.func can be found here: torch.func — PyTorch 2.3 documentation