Gradient of a mixed network's output with respect to ONE tensor

Suppose I have an MLP: R^d → R, applied to each row x_i of a batch x. Since MLP(x_i) does not depend on x_j for j != i, I can get the gradient of each output with respect to its own input x_i by simply calling grad(output.sum(), x).
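A minimal sketch of why the trick works for the MLP, assuming a toy `torch.nn.Sequential` MLP and a random batch (sizes are arbitrary). It compares the single backward pass on `output.sum()` against one backward pass per row:

```python
import torch

torch.manual_seed(0)

n, d = 5, 3
# Toy MLP R^d -> R, applied independently to each row x_i of x.
mlp = torch.nn.Sequential(
    torch.nn.Linear(d, 16),
    torch.nn.Tanh(),
    torch.nn.Linear(16, 1),
)

x = torch.randn(n, d, requires_grad=True)
output = mlp(x).squeeze(-1)  # shape (n,): one scalar per row

# Sum trick: a single backward pass. retain_graph=True keeps the graph
# alive for the per-row check below.
(g_sum,) = torch.autograd.grad(output.sum(), x, retain_graph=True)

# Reference: one backward pass per row, keeping only row i of each gradient.
g_loop = torch.stack([
    torch.autograd.grad(output[i], x, retain_graph=True)[0][i]
    for i in range(n)
])

print(torch.allclose(g_sum, g_loop, atol=1e-6))  # True
```

They agree because output[j] has no dependence on x[i] for j != i, so the sum adds only zeros to row i.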

However, my network is now a GNN, so the above trick no longer works (the computation for each x_i depends on x_j for all j). Again, the output is a real number for each x_i, and I want its gradient with respect to x_i alone.
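To make the failure concrete, here is a sketch with a hypothetical one-layer GNN (mean aggregation over a small path graph, then a linear readout; `adj`, `W`, `v` and `gnn` are illustrative, not from the question). The straightforward way to get dy_i/dx_i is one backward pass per node:

```python
import torch

torch.manual_seed(0)

n, d, hidden = 4, 3, 8
# Hypothetical toy graph: path 0-1-2-3 with self-loops, row-normalised.
adj = torch.eye(n)
for i in range(n - 1):
    adj[i, i + 1] = adj[i + 1, i] = 1.0
adj = adj / adj.sum(dim=1, keepdim=True)
W = torch.randn(d, hidden)
v = torch.randn(hidden)

def gnn(x):
    # Hypothetical one-layer GNN: mean over neighbours, then a linear readout,
    # so y_j depends on x_j and on the features of j's neighbours.
    return torch.tanh(adj @ x @ W) @ v  # (n, d) -> (n,)

x = torch.randn(n, d, requires_grad=True)
y = gnn(x)

# Sum trick: row i of g_sum is sum_j dy_j/dx_i (includes neighbours' outputs).
(g_sum,) = torch.autograd.grad(y.sum(), x, retain_graph=True)

# Per-node gradient dy_i/dx_i: one backward pass per node, keep row i only.
g_self = torch.stack([
    torch.autograd.grad(y[i], x, retain_graph=True)[0][i]
    for i in range(n)
])

print(torch.allclose(g_sum, g_self))  # False: neighbours add cross terms
```

The loop costs n backward passes, which is fine for small graphs but scales linearly with the number of nodes.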

Any ideas? Thanks!

Hmm, I feel like I don’t completely understand. torch.autograd.grad(output.sum(), x) computes the gradient of the summed output with respect to x regardless of the dependency structure of the function. Or are you saying the inputs depend on one another?
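Both points can be checked against the full Jacobian. The sketch below assumes PyTorch 2.0+ for `torch.func.jacrev` and uses a hypothetical toy model where every output depends on every input (half self-weight, half mean over all nodes). `grad(y.sum(), x)` does compute the gradient of the sum correctly, but that gradient is the sum of the Jacobian over outputs, while the per-node quantity is its diagonal:

```python
import torch
from torch.func import jacrev  # requires PyTorch >= 2.0

torch.manual_seed(0)

n, d, hidden = 4, 3, 8
# Hypothetical toy "GNN": half self-weight plus a mean over all nodes,
# so every output y_j depends on every input x_i.
adj = 0.5 * torch.eye(n) + 0.5 / n
W = torch.randn(d, hidden)
v = torch.randn(hidden)

def gnn(x):
    return torch.tanh(adj @ x @ W) @ v  # (n, d) -> (n,)

x = torch.randn(n, d)

# Full Jacobian: J[j, i, :] = d y_j / d x_i, shape (n, n, d).
J = jacrev(gnn)(x)

# grad(y.sum(), x) equals the sum over outputs j of J[j, i, :].
xr = x.clone().requires_grad_(True)
(g_sum,) = torch.autograd.grad(gnn(xr).sum(), xr)
print(torch.allclose(g_sum, J.sum(dim=0), atol=1e-6))  # True

# The per-node gradient dy_i/dx_i is the "diagonal" block J[i, i, :].
idx = torch.arange(n)
g_self = J[idx, idx]  # shape (n, d)
```

Materialising J costs O(n^2 d) memory, so this is mainly useful for small graphs or for checking a cheaper method.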

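The case in which that trick works is the following: grad_xi (sum_j f_j(x_j)) = grad_xi f_i(x_i). In my case, however, each f_i = f_i(x1, x2, …) depends on all the inputs, so grad_xi (sum_j f_j) = sum_j grad_xi f_j also picks up the cross terms j != i.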
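Written out with y_j = f_j(x_1, …, x_n), the gradient of the sum splits into the wanted term and a sum of cross terms:

```latex
\nabla_{x_i} \sum_{j=1}^{n} f_j(x_1, \dots, x_n)
  = \sum_{j=1}^{n} \nabla_{x_i} f_j
  = \underbrace{\nabla_{x_i} f_i}_{\text{wanted}}
  + \underbrace{\sum_{j \neq i} \nabla_{x_i} f_j}_{\text{cross terms}}
```

For the MLP, f_j depends only on x_j, so every cross term is zero and one backward pass returns exactly grad_xi f_i. For a GNN, grad_xi f_j is nonzero whenever x_i feeds into node j's computation, so one backward pass on the sum returns the column sums of the Jacobian rather than its diagonal blocks; recovering the diagonal generally takes one backward pass per node, or the full Jacobian.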