Suppose I have an MLP: R^d → R. Since MLP(x_i) does not depend on x_j for j != i, I can retrieve the gradient of each output with respect to its own input simply by calling `grad(output.sum(), x)`.
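
For concreteness, here is a minimal sketch of the trick I mean in the MLP case (plain PyTorch; the network and shapes are just placeholders):

```python
import torch

# Placeholder setup: a tiny MLP R^d -> R applied row-wise to a batch of inputs x_i.
d = 8
mlp = torch.nn.Sequential(
    torch.nn.Linear(d, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 1),
)

x = torch.randn(16, d, requires_grad=True)   # 16 independent inputs x_i
output = mlp(x).squeeze(-1)                  # one real number per x_i

# Since output[i] only depends on x[i], summing does not mix gradients:
# grads[i] == d output[i] / d x[i]
(grads,) = torch.autograd.grad(output.sum(), x)
print(grads.shape)  # torch.Size([16, 8])
```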
However, I am now in the case where my network is a GNN, so the above trick doesn't work anymore: the computation for each x_i depends on x_j for all j, so `grad(output.sum(), x)` would also pick up the cross terms coming from the other outputs. Again, my output is a real number for each x_i, and I want to compute its gradient with respect to x_i only.
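
To make the question concrete, here is the kind of per-node gradient I mean, written as a naive per-output loop (sketch only; `gnn` and `edge_index` are placeholders for my actual model and graph):

```python
import torch

# Placeholder signature: gnn(x, edge_index) returns one real number per node.
def per_node_grads(gnn, x, edge_index):
    x = x.clone().requires_grad_(True)
    output = gnn(x, edge_index)                    # shape [num_nodes]
    grads = []
    for i in range(output.shape[0]):
        # Gradient of output[i] w.r.t. *all* node features (every x_j influences it)...
        (g,) = torch.autograd.grad(output[i], x, retain_graph=True)
        # ...but I only keep the block corresponding to x_i itself.
        grads.append(g[i])
    return torch.stack(grads)                      # shape [num_nodes, d]
```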
Any ideas? Thanks!