Efficient gradient calculation for multi-output/multi-head models

I have a graph neural network (GNN) where, at each node, N independent readout MLPs each produce a scalar. This means the model has a large number of shared weights in the backbone and a smaller number of independent weights in the readout MLPs. During training (prior to a .backward() call) I need the gradient of each readout with respect to the input, which is an array. I have a naive method that works, where I zero grad_outputs except for the head of interest:

batch_size, n_outputs = output.shape
gradients = []

for i in range(n_outputs):
    grad_output = torch.zeros_like(output)
    grad_output[:, i] = 1.0

    # Retain the graph for all outputs except the last one
    retain = True if i < n_outputs - 1 else training

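    gradient = torch.autograd.grad(
        outputs=output,
        inputs=inputs,  # assumed name for the input tensor
        grad_outputs=grad_output,
        retain_graph=retain,
    )[0]
    gradients.append(gradient)
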
combined_gradient = torch.stack(gradients, dim=-1)
return -1 * combined_gradient

However, this seems inefficient: the shared backbone is identical for every head, yet the backward pass through it may be repeated N times. Is there a more efficient way of doing this? For example, first calculating the gradient of the backbone output wrt the inputs, and then the gradient of each head's output wrt the backbone output? Any insights appreciated!
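
For reference, one possible direction is to keep the same one-hot grad_outputs idea but hand all N of them to a single torch.autograd.grad call via is_grads_batched=True, which PyTorch documents as an experimental vmap-based feature. The sketch below is only an illustration under assumptions: ToyMultiHead is a hypothetical plain-MLP stand-in for the GNN, the names loop_gradients and batched_gradients are made up here, and the -1 sign from the snippet above is left out. It checks the batched call against the loop on random data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyMultiHead(nn.Module):
    """Hypothetical stand-in: shared backbone + N scalar readout heads (not a GNN)."""

    def __init__(self, in_dim=3, hidden=16, n_heads=4):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(in_dim, hidden), nn.Tanh())
        self.heads = nn.ModuleList(
            [nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh(), nn.Linear(hidden, 1))
             for _ in range(n_heads)]
        )

    def forward(self, x):
        h = self.backbone(x)  # shared features, computed once
        return torch.cat([head(h) for head in self.heads], dim=-1)  # (batch, n_heads)


def loop_gradients(output, inputs):
    """Naive reference: one autograd.grad call per head."""
    grads = []
    for i in range(output.shape[1]):
        grad_output = torch.zeros_like(output)
        grad_output[:, i] = 1.0
        (g,) = torch.autograd.grad(
            output, inputs, grad_outputs=grad_output, retain_graph=True
        )
        grads.append(g)
    return torch.stack(grads, dim=-1)  # (batch, in_dim, n_heads)


def batched_gradients(output, inputs):
    """Same result from a single call, with the N one-hot vectors batched."""
    batch_size, n_outputs = output.shape
    # basis[i] equals the grad_output used for head i in the loop version.
    basis = torch.eye(n_outputs, dtype=output.dtype, device=output.device)
    basis = basis.unsqueeze(1).expand(n_outputs, batch_size, n_outputs)
    (g,) = torch.autograd.grad(
        output, inputs, grad_outputs=basis, is_grads_batched=True
    )
    # g has shape (n_heads, batch, in_dim); move the head axis last.
    return g.permute(1, 2, 0)


model = ToyMultiHead()
x = torch.randn(5, 3, requires_grad=True)
out = model(x)
reference = loop_gradients(out, x)
batched = batched_gradients(out, x)
print(torch.allclose(reference, batched, atol=1e-5))
```

Note that this vectorizes the N backward passes rather than removing them: in reverse mode, each head still needs its own vector-Jacobian product through the backbone. Splitting the chain rule at the backbone output would require the full Jacobian of the backbone features wrt the inputs, which takes one backward pass per feature dimension, so it likely only pays off when the feature dimension is smaller than N. Whether the batched call also behaves well with create_graph=True during training would need to be verified for the actual model.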