I am wondering whether there is a standard way to find out where the gradients come from when they add up to form the gradient of my torch.nn.Parameter during backward(). I have a custom recurrent network in which the same parameters are reused at every step of the temporal sequence I run in forward(). So when the gradients are computed, they come from multiple items in the sequence and, if I understand correctly, they are summed to give me parameter.grad after backward(). Is it possible to determine where those individual contributions come from?
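To make the summation concrete, here is a minimal sketch. It assumes a hypothetical toy recurrence h = tanh(W @ h + x) with a single shared weight matrix W. The trick is a suggestion, not a built-in PyTorch feature: use a fresh W.clone() at every step. clone() is differentiable, so gradients still reach W, but each clone is a separate node in the graph. As a result, torch.autograd.grad with respect to the clones returns one contribution per step, and those contributions should sum to W.grad.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy recurrence: h_{t+1} = tanh(W @ h_t + x_t), W shared by all steps.
hidden, steps = 4, 5
W = torch.nn.Parameter(0.5 * torch.randn(hidden, hidden))
xs = torch.randn(steps, hidden)
h = torch.randn(hidden)

aliases = []                      # one per-step stand-in for W
for x in xs:
    W_t = W.clone()               # differentiable copy: grads still flow back into W
    aliases.append(W_t)
    h = torch.tanh(W_t @ h + x)
loss = h.pow(2).sum()

# dLoss/dW_t for each step t, i.e. the contribution of step t to W.grad.
per_step = torch.autograd.grad(loss, aliases, retain_graph=True)

# A normal backward() still accumulates the total into W.grad.
loss.backward()

for t, g in enumerate(per_step):
    print(f"step {t}: contribution norm = {g.norm().item():.4f}")

# The per-step contributions should sum to the accumulated gradient.
print(torch.allclose(torch.stack(per_step).sum(dim=0), W.grad, atol=1e-6))
```

The same per-step clones also work with Tensor.register_hook if you prefer to collect the contributions during an ordinary backward() call instead of a separate torch.autograd.grad call.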
Could you please let me know whether I have understood you correctly:
You are trying to calculate the gradient of the loss with respect to, let’s say, each token in an input sentence separately.
I have worked on a similar problem before using backward hooks, though with a BERT-based model rather than an RNN. So backward hooks may help if that is what you are trying to do.
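Here is a minimal sketch of the backward-hook approach for per-token gradients. It uses a hypothetical toy model in place of BERT: an nn.Embedding feeding an nn.GRU and a linear classifier. Tensor.register_hook on the embedding output captures dLoss/d(embeddings) during backward(), and row t of that gradient belongs to token t. The model, sizes, and token ids are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: embed tokens, run a GRU, classify from the last step.
vocab_size, dim, num_classes = 100, 8, 3
embedding = nn.Embedding(vocab_size, dim)
gru = nn.GRU(dim, dim, batch_first=True)
classifier = nn.Linear(dim, num_classes)

token_ids = torch.tensor([[5, 17, 42, 8]])   # one sentence of 4 (distinct) tokens
target = torch.tensor([1])

captured = {}

def save_grad(grad):
    # Called during backward() with dLoss/d(emb), shape (1, seq_len, dim).
    # Returning None leaves the gradient unchanged.
    captured["emb_grad"] = grad.detach().clone()

emb = embedding(token_ids)                   # non-leaf tensor, (1, seq_len, dim)
handle = emb.register_hook(save_grad)        # tensor backward hook
output, _ = gru(emb)                         # (1, seq_len, dim)
logits = classifier(output[:, -1])           # use the last time step
loss = nn.functional.cross_entropy(logits, target)
loss.backward()
handle.remove()

# One gradient vector per token; its norm is a simple per-token score.
per_token = captured["emb_grad"][0]          # (seq_len, dim)
for tok, g in zip(token_ids[0].tolist(), per_token):
    print(f"token {tok}: grad norm = {g.norm().item():.4f}")
```

Because the token ids here are distinct, each row of embedding.weight.grad for those ids equals the corresponding per-token gradient. With repeated tokens those rows would hold sums, which is the same accumulation described in the question.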