Gradient of dot product for every index pair

Hello!

I am interested in computing \nabla_{x_i}\left(s(x_i)^\top s(x_j)\right).

The dot products are computed as product_pairs = s.matmul(s.transpose(-2, -1)).

However, I am unsure how to proceed for the gradients. grad(product_pairs.sum(), x) is not the way to go, since the sum also reduces over j: its output has only two free indices (i and the dimension d), whereas I expect three (i, j, and d).
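
For concreteness, here is a minimal example of the setup and the shape mismatch (the tanh(x @ W) map and the shapes N, d_in, d_out are just stand-ins for my actual s):

```python
import torch

N, d_in, d_out = 4, 3, 5
x = torch.randn(N, d_in, requires_grad=True)
W = torch.randn(d_in, d_out)

s = torch.tanh(x @ W)                             # stand-in for s(x), shape (N, d_out)
product_pairs = s.matmul(s.transpose(-2, -1))     # (N, N), entry (i, j) = s(x_i)^T s(x_j)

g, = torch.autograd.grad(product_pairs.sum(), x)  # (N, d_in) -- j has been summed away
print(g.shape)  # torch.Size([4, 3]), but I want (N, N, d_in) = (4, 4, 3)
```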

Any solutions for this? Thanks!

Are you looking for something like torch.vmap? (torch.vmap — PyTorch 2.5 documentation)
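
An untested sketch using torch.func, assuming your s can be written as a function of a single sample (the tanh(x @ W) body below is just a placeholder for your model): define the dot product for one pair, take the gradient with respect to the first argument only, and vmap twice to get the full (N, N, d_in) tensor.

```python
import torch
from torch.func import jacrev, vmap

N, d_in, d_out = 4, 3, 5
x = torch.randn(N, d_in)
W = torch.randn(d_in, d_out)

def s_single(xi):        # placeholder for your s, mapping one sample (d_in,) -> (d_out,)
    return torch.tanh(xi @ W)

def pair_dot(xi, xj):    # scalar s(x_i)^T s(x_j) for a single pair
    return s_single(xi) @ s_single(xj)

# gradient wrt the first argument only (x_j is treated as a constant)
grad_i = jacrev(pair_dot, argnums=0)  # returns a (d_in,) tensor

# vmap over j (inner) and over i (outer) -> shape (N, N, d_in)
grads = vmap(vmap(grad_i, in_dims=(None, 0)), in_dims=(0, None))(x, x)
print(grads.shape)  # torch.Size([4, 4, 3]) -- indices (i, j, d)
```

One subtlety: this treats x_j as a constant even on the diagonal i == j. If you instead want the total derivative of s(x_i)^\top s(x_i) there, it is twice the value above, since both one-sided terms are equal by symmetry.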