Hello!
I am interested in computing $\nabla_{x_i}\left(s(x_i)^T s(x_j)\right)$.
The dot products are computed as `product_pairs = s.matmul(s.transpose(-2, -1))`.
However, I am unsure how to proceed for the gradients. `grad(product_pairs.sum(), x)` is not the way to go, since it sums over `j` as well, and the output has only two free indices (`i` and the dimension `d`), whereas I expect three (`i`, `j`, and `d`).
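To make the shape issue concrete, here is a minimal sketch. The sizes `N`, `d`, `k` and the small network standing in for `s` are placeholders made up for illustration, and the double loop at the end is only a slow reference for the tensor I am after (assuming entry `[i, j]` should hold the gradient of `product_pairs[i, j]` with respect to `x[i]`), not something I would want to run in practice.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: N points, input dimension d, feature dimension k.
N, d, k = 4, 3, 5
x = torch.randn(N, d, requires_grad=True)

# Hypothetical stand-in for the function s; any differentiable map works.
lin = torch.nn.Linear(d, k)
s = torch.tanh(lin(x))                          # shape (N, k)

# Pairwise dot products s(x_i)^T s(x_j).
product_pairs = s.matmul(s.transpose(-2, -1))   # shape (N, N)

# Summing first collapses the j index: the result has the shape of x.
(g,) = torch.autograd.grad(product_pairs.sum(), x, retain_graph=True)
print(g.shape)                                  # torch.Size([4, 3])

# Slow reference for the desired tensor with free indices (i, j, d):
# one backward pass per pair, keeping only the row for x_i.
expected = torch.zeros(N, N, d)
for i in range(N):
    for j in range(N):
        (g_ij,) = torch.autograd.grad(product_pairs[i, j], x, retain_graph=True)
        expected[i, j] = g_ij[i]
print(expected.shape)                           # torch.Size([4, 4, 3])
```

The loop needs `N * N` backward passes, which is exactly what I would like to avoid.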
Any solutions for this? Thanks!