I want to compute the gradient of a tensor A of shape (batch_size, M) w.r.t. a tensor B of shape (batch_size, L) efficiently, i.e. the Jacobian dA/dB of shape (batch_size, M, L). torch.autograd.grad only accepts scalar outputs. For a non-scalar output it needs an explicit grad_outputs tensor, and then it returns a vector-Jacobian product, not the full Jacobian. Iterating over each element a_ij of A and computing its gradient w.r.t. B seems quite inefficient. Also, for my use case, an analytic solution seems hard to find.
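The simplest case of the problem would be something like:
```python
import torch

B = torch.rand((4, 10), requires_grad=True)  # batch_size=4, L=10
W = torch.rand((20, 10))                      # M=20
A = torch.matmul(B, W.T)                      # shape (4, 20)
torch.autograd.grad(A, B)  # RuntimeError: grad can be implicitly created only for scalar outputs
```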
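For reference, here is a minimal sketch of the iteration the question wants to avoid. It assumes that each row of A depends only on the matching row of B, with no cross-sample operations such as batch norm. Under that assumption, the loop only needs M backward passes rather than batch_size * M, because summing A[:, j] over the batch does not mix samples. The names batch_size, M, L, rows and jac_loop are illustrative.

```python
import torch

torch.manual_seed(0)
batch_size, M, L = 4, 20, 10
B = torch.rand((batch_size, L), requires_grad=True)
W = torch.rand((M, L))
A = torch.matmul(B, W.T)  # shape (batch_size, M)

# Assumption: samples are independent, so differentiating the batch-sum of
# A[:, j] yields row j of every sample's Jacobian at once.
rows = []
for j in range(M):
    # retain_graph=True because the same graph is reused for every j
    (g,) = torch.autograd.grad(A[:, j].sum(), B, retain_graph=True)
    rows.append(g)  # g has shape (batch_size, L)

jac_loop = torch.stack(rows, dim=1)  # shape (batch_size, M, L)
print(jac_loop.shape)  # torch.Size([4, 20, 10])
```

For this linear example every per-sample Jacobian is simply W, which makes the result easy to check.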
Is there a way to compute dA/dB without iteration?
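One way to avoid the Python loop is sketched below. It assumes PyTorch 2.0 or newer, where torch.func provides jacrev and vmap, and it assumes the computation can be written as a per-sample function. The function f is hypothetical and wraps the toy example. jacrev(f) gives the (M, L) Jacobian of a single sample, and vmap maps it over the batch dimension.

```python
import torch
from torch.func import jacrev, vmap

torch.manual_seed(0)
batch_size, M, L = 4, 20, 10
B = torch.rand((batch_size, L))  # torch.func does not need requires_grad
W = torch.rand((M, L))

def f(b):
    # Hypothetical per-sample function: b has shape (L,), output has shape (M,)
    return torch.matmul(W, b)

# jacrev(f): (L,) -> (M, L) Jacobian of one sample;
# vmap applies it across the batch dimension without a Python loop.
jac = vmap(jacrev(f))(B)
print(jac.shape)  # torch.Size([4, 20, 10])
```

torch.autograd.functional.jacobian(..., vectorize=True) is another option. Applied to the whole batch, however, it returns a (batch_size, M, batch_size, L) tensor whose off-diagonal batch blocks are zero, so the per-sample form above is usually more economical.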
I think A.backward(gradient=torch.ones_like(A)) is the same as A.sum().backward(). Consequently, you don't compute the derivative of every element of A w.r.t. every element of B. Instead, you compute the derivative of the sum of all elements of A w.r.t. B. Therefore, you get a gradient of shape (4, 10), not (4, 20, 10).
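A small check of this claim on the toy example is sketched below. Both variants produce the same (4, 10) gradient. For this linear map, each row of that gradient equals W.sum(dim=0), which shows that the (4, 20, 10) Jacobian has been summed over the M dimension. The variable names are illustrative.

```python
import torch

torch.manual_seed(0)
B = torch.rand((4, 10), requires_grad=True)
W = torch.rand((20, 10))

# Variant 1: backward with an all-ones gradient (a vector-Jacobian product)
A = torch.matmul(B, W.T)
A.backward(gradient=torch.ones_like(A))
grad_ones = B.grad.clone()

# Variant 2: backward on the scalar sum of A (rebuild the graph, reset B.grad)
B.grad = None
A = torch.matmul(B, W.T)
A.sum().backward()
grad_sum = B.grad.clone()

print(grad_ones.shape)                      # torch.Size([4, 10]), not (4, 20, 10)
print(torch.allclose(grad_ones, grad_sum))  # True
# Each row equals W.sum(dim=0): the Jacobian has been summed over M.
print(torch.allclose(grad_ones, W.sum(dim=0).expand(4, 10)))  # True
```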