I have a graph neural network (GNN) where at each node there are `N` independent readout MLPs that each produce a scalar. This means the model has a large number of shared weights in the shared backbone and a smaller number of independent weights in the readout MLPs. During training (prior to a `.backward()` call) I need the gradient of each readout w.r.t. the input, which is an array. I have a naive method that works, where I zero `grad_outputs` except for the head of interest:

```
batch_size, n_outputs = output.shape
gradients = []
for i in range(n_outputs):
    # One-hot grad_outputs: select only the i-th readout head
    grad_output = torch.zeros_like(output)
    grad_output[:, i] = 1.0
    # Retain the graph for all outputs except the last one
    retain = True if i < n_outputs - 1 else training
    gradient = torch.autograd.grad(
        outputs=output,
        inputs=input,
        grad_outputs=grad_output,
        retain_graph=retain,
        create_graph=training,
        allow_unused=True,
    )[0]
    gradients.append(gradient)
combined_gradient = torch.stack(gradients, dim=-1)
return -1 * combined_gradient
```
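
For reference, I believe the Python loop itself could be collapsed into one call by stacking the one-hot `grad_outputs` and passing `is_grads_batched=True` (rough sketch below, assuming a recent PyTorch and that `vmap` can handle every op in the backbone), but as far as I can tell it still pushes one cotangent per head back through the whole shared graph:

```
# Sketch: all N one-hot grad_outputs in a single batched call.
# grad_outputs gains a leading dimension of size n_outputs that is vmapped over.
eye = torch.eye(n_outputs, dtype=output.dtype, device=output.device)
batched_grad_outputs = eye.unsqueeze(1).repeat(1, batch_size, 1)  # (n_outputs, batch, n_outputs)

jacobian = torch.autograd.grad(
    outputs=output,
    inputs=input,
    grad_outputs=batched_grad_outputs,
    create_graph=training,
    allow_unused=True,
    is_grads_batched=True,
)[0]  # (n_outputs, *input.shape)

combined_gradient = jacobian.movedim(0, -1)  # same layout as torch.stack(..., dim=-1)
```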

Either way, this seems inefficient: the backward pass through the shared backbone is repeated `N` times even though its weights are shared. Is there a more efficient way of doing this? For example, first calculating the gradient of the backbone output w.r.t. the inputs and then the gradients of the readouts w.r.t. the backbone output? Any insights appreciated!
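
Here is a rough, self-contained sketch of the split I have in mind, with a toy `backbone` and toy `heads` standing in for my real model (everything below is hypothetical). If I understand it right, this only pays off when the backbone's feature dimension is smaller than `N`, because `d(features)/d(input)` itself needs one backward pass through the backbone per feature dimension:

```
import torch

torch.manual_seed(0)
batch_size, in_dim, feat_dim, n_heads = 4, 3, 8, 5

# Toy stand-ins for the shared backbone and the N independent readout heads.
backbone = torch.nn.Sequential(torch.nn.Linear(in_dim, feat_dim), torch.nn.Tanh())
heads = torch.nn.ModuleList(torch.nn.Linear(feat_dim, 1) for _ in range(n_heads))

x = torch.randn(batch_size, in_dim, requires_grad=True)
features = backbone(x)                                    # shared part, run once
output = torch.cat([h(features) for h in heads], dim=1)   # (batch, n_heads)

# d(features)/d(x): one backward pass through the backbone per feature dimension.
rows = []
for j in range(feat_dim):
    g = torch.zeros_like(features)
    g[:, j] = 1.0
    rows.append(torch.autograd.grad(features, x, g, retain_graph=True)[0])
J_backbone = torch.stack(rows, dim=1)                     # (batch, feat_dim, in_dim)

# d(output_i)/d(features): autograd stops at `features`, so the backbone is not
# traversed again; each head is tiny, so these N passes are cheap.
rows = []
for i in range(n_heads):
    g = torch.zeros_like(output)
    g[:, i] = 1.0
    rows.append(torch.autograd.grad(output, features, g, retain_graph=True)[0])
J_heads = torch.stack(rows, dim=1)                        # (batch, n_heads, feat_dim)

# Chain rule: d(output)/d(x) = d(output)/d(features) @ d(features)/d(x)
J_full = torch.bmm(J_heads, J_backbone)                   # (batch, n_heads, in_dim)
```

The two inner loops could presumably also be batched as above; my main question is whether splitting the graph at the backbone output like this is the right approach, or whether there is a more standard way to avoid redoing the shared work.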