Access gradients before they are aggregated to the node

How can I access the gradients during the backward pass, before they are accumulated into the parameter's .grad, in PyTorch?

When multiple inputs use the same parameter, their gradients are summed up during the backward pass.
For example, in the following case, how can I get the gradients of linear's parameters for each element of the input sequence?

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()

# One batch of 3 sequence elements with 4 features each: shape (1, 3, 4)
inputs = torch.tensor([[[1, 2, 3, 4],
                        [5, 6, 7, 8],
                        [9, 10, 11, 12]]]).float()

labels = torch.tensor([[3, 6, 0]])  # shape (1, 3)

linear = nn.Linear(4, 8)

out = linear(inputs)  # shape (1, 3, 8)

loss = loss_fn(out.view(-1, out.shape[-1]), labels.view(-1))
loss.backward()  # linear.weight.grad now holds the sum of the per-element contributions

Just like the way I can calculate them separately here:

grads = []
for i in range(3):
    linear.zero_grad()
    # loss of element i alone, divided by 3 to match the mean reduction above
    loss = loss_fn(out[:, i, :].view(-1, out.shape[-1]), labels[:, i].view(-1)) / 3
    loss.backward(retain_graph=True)
    grads.append(linear.weight.grad.clone())
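
The same per-element gradients can also be read off by intercepting the gradient that reaches out with a tensor hook, before the matmul backward sums it into linear.weight.grad. A rough, self-contained sketch (just one possible workaround, not necessarily the cleanest):

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
linear = nn.Linear(4, 8)

inputs = torch.tensor([[[1, 2, 3, 4],
                        [5, 6, 7, 8],
                        [9, 10, 11, 12]]]).float()
labels = torch.tensor([[3, 6, 0]])

out = linear(inputs)

# Capture dL/d(out), shape (1, 3, 8), before it is propagated into the weights
grad_out = []
out.register_hook(lambda g: grad_out.append(g.detach().clone()))

loss = loss_fn(out.view(-1, out.shape[-1]), labels.view(-1))
loss.backward()

# Per-element weight gradient = outer product of that element's output grad and its input
per_elem = torch.stack([torch.outer(grad_out[0][0, i], inputs[0, i]) for i in range(3)])
print(torch.allclose(per_elem.sum(dim=0), linear.weight.grad))  # True

Each per_elem[i] should match grads[i] from the loop above.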

You might be looking for the Per-sample-gradients tutorial in the PyTorch documentation.
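
That tutorial computes per-sample gradients with torch.func (functional_call, grad, and vmap, available in PyTorch 2.x). A minimal sketch adapted to the example above, assuming each of the 3 rows is treated as its own sample (the leading batch dimension is dropped, and the per-sample losses are not divided by 3 the way the loop does):

import torch
import torch.nn as nn
from torch.func import functional_call, vmap, grad

loss_fn = nn.CrossEntropyLoss()
linear = nn.Linear(4, 8)

# Treat each row as one sample: shape (3, 4) instead of (1, 3, 4)
inputs = torch.tensor([[1, 2, 3, 4],
                       [5, 6, 7, 8],
                       [9, 10, 11, 12]]).float()
labels = torch.tensor([3, 6, 0])

params = {name: p.detach() for name, p in linear.named_parameters()}

def compute_loss(params, x, y):
    # x: (4,), y: scalar label; add a batch dimension for the module call
    out = functional_call(linear, params, (x.unsqueeze(0),))
    return loss_fn(out, y.unsqueeze(0))

# grad(...) differentiates w.r.t. the first argument (params);
# vmap maps that computation over the sample dimension of inputs and labels
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, inputs, labels)

print(per_sample_grads["weight"].shape)  # torch.Size([3, 8, 4])
print(per_sample_grads["bias"].shape)    # torch.Size([3, 8])

Each per_sample_grads["weight"][i] should equal 3 * grads[i] from the loop above, and per_sample_grads["weight"].sum(0) / 3 should reproduce linear.weight.grad from the aggregated backward.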
