How can I access the gradients during the backward pass, before they are accumulated into the parameter's .grad, in PyTorch?
When multiple inputs share the same parameter, their gradients are summed during backward.
For example, in the following case, how can I get the gradients of the parameters of linear with respect to each position of the input sequence?
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
inputs = torch.tensor([[[1, 2, 3, 4],
                        [5, 6, 7, 8],
                        [9, 10, 11, 12]]]).float()  # shape (1, 3, 4): one batch, 3 positions
labels = torch.tensor([[3, 6, 0]])                  # shape (1, 3)
linear = nn.Linear(4, 8)
out = linear(inputs)                                # shape (1, 3, 8)
loss = loss_fn(out.view(-1, out.shape[-1]), labels.view(-1))
loss.backward(retain_graph=True)  # retain the graph so the loop below can backprop through out again
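At this point linear.weight.grad holds the gradient summed over all three positions. To compare against the per-position gradients below, I save a copy first (full_grad is just a helper name for this sketch, not part of the original code):

full_grad = linear.weight.grad.clone()  # aggregated gradient, kept for the sanity check below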
Just like the per-position gradients I can compute separately here:
grads = []
for i in range(3):
    linear.zero_grad()
    # loss for position i only; divide by 3 to match the mean reduction of the full loss
    loss = loss_fn(out[:, i, :].view(-1, out.shape[-1]), labels[:, i].view(-1)) / 3
    loss.backward(retain_graph=True)
    grads.append(linear.weight.grad.clone())
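For reference, a sanity check I would expect to hold up to floating-point error (assuming full_grad was saved as above): the per-position gradients sum to the aggregated gradient from the single backward pass.

print(torch.allclose(torch.stack(grads).sum(dim=0), full_grad))  # expected: True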