Ok, let’s get a bit more specific:
>>> print(grad_dict["fc2"]["grad_input"][0][0].size())
torch.Size([4, 30])
>>> print(grad_dict["fc2"]["grad_input"][0][1].size())
torch.Size([4, 20])
>>> print(grad_dict["fc2"]["grad_input"][0][2].size())
torch.Size([20, 30])
>>> print(grad_dict["fc2"]["grad_output"][0][0].size())
torch.Size([4, 30])
`grad_input` is a 3-tuple; my guess:
- [0] is the derivative of the loss w.r.t. the layer input
- [1] is the derivative of the loss w.r.t. the layer output (before activation)
- [2] is the derivative of the loss w.r.t. the layer weights
`grad_output` is a 1-tuple; perhaps it's the derivative of the loss w.r.t. the layer output after activation?
Please correct me if I am wrong.
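For context, here is a minimal sketch of how a `grad_dict` like the one above could be populated. This is an assumption about the setup, not the original code: I'm assuming `fc2 = nn.Linear(20, 30)` and a batch size of 4 (which matches the printed sizes), and the legacy `Module.register_backward_hook`, which for a `Linear` layer can pass a multi-element `grad_input` tuple. Note that this hook is deprecated; the newer `register_full_backward_hook` passes `grad_input` as a 1-tuple containing only the gradient w.r.t. the layer input.

```python
# Hypothetical reconstruction of the hook setup (assumptions: fc2 is
# nn.Linear(20, 30), batch size 4, legacy register_backward_hook).
import torch
import torch.nn as nn

grad_dict = {}

def make_hook(name):
    def hook(module, grad_input, grad_output):
        # Record the gradient tuples for each backward() call.
        entry = grad_dict.setdefault(name, {"grad_input": [], "grad_output": []})
        entry["grad_input"].append(grad_input)
        entry["grad_output"].append(grad_output)
    return hook

fc2 = nn.Linear(20, 30)
fc2.register_backward_hook(make_hook("fc2"))  # deprecated, kept for the shapes above

x = torch.randn(4, 20, requires_grad=True)
loss = fc2(x).sum()
loss.backward()

# grad_output is always a 1-tuple: the gradient of the loss w.r.t.
# the module's output, here of shape [4, 30].
print(grad_dict["fc2"]["grad_output"][0][0].size())
```

The exact contents and ordering of the legacy `grad_input` tuple depend on the PyTorch version, which is part of why `register_full_backward_hook` (with its unambiguous 1-tuple `grad_input`) is the recommended API.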