I tested register_backward_hook on nn.Sequential as below.
import torch
import torch.nn as nn
from torch.autograd import Variable

a = nn.Sequential(nn.Linear(5,3), nn.Tanh(), nn.Linear(3,2))

# Print how many gradients the hook receives, then each one.
def hookFunc(module, gradInput, gradOutput):
    print(len(gradInput))
    for v in gradInput:
        print(v)

# Hook the whole Sequential container, not an individual layer.
a.register_backward_hook(hookFunc)
input = Variable(torch.randn(4,5))
output = a(input)
target = torch.FloatTensor(4,2).fill_(1)
output.backward(target)
The output is as follows.
3
Variable containing:
-0.1122 0.1216 0.7935
-0.1122 0.1216 0.7935
-0.1122 0.1216 0.7935
-0.1122 0.1216 0.7935
[torch.FloatTensor of size 4x3]
Variable containing:
-0.5910 -0.7340 -0.4239
-0.5910 -0.7340 -0.4239
[torch.FloatTensor of size 2x3]
Variable containing:
4
4
[torch.FloatTensor of size 2]
So it seems that when register_backward_hook is used on nn.Sequential, the hook only receives the gradients for the last element of the nn.Sequential.
Is this intended? To get the gradients for a specific element, should I register the hook on that element rather than on the nn.Sequential module?
Another question: why does the description at http://pytorch.org/docs/_modules/torch/nn/modules/module.html#Module.register_backward_hook say that the hook shouldn’t modify its arguments (gradInput, for example)? Is there a reason for that? And if so, what is a better way to manually modify the gradient inputs that are passed back to the previous module?
Yeah it’s a known bug (GitHub issue), but it’s on hold because of the large autograd refactor going on right now. Sorry for that.
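In the meantime, registering a hook on each element of the container is one way to see per-layer gradients. The sketch below is only an illustration and makes two assumptions that are not from this thread: it uses a recent PyTorch release where plain tensors replace Variable and where register_full_backward_hook is available, and the names captured and make_hook are hypothetical helpers.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(5, 3), nn.Tanh(), nn.Linear(3, 2))

# Hypothetical container: layer index -> shapes of that layer's grad_input.
captured = {}

def make_hook(index):
    def hook(module, grad_input, grad_output):
        # grad_input holds gradients w.r.t. the module's inputs; entries may be None.
        captured[index] = [None if g is None else tuple(g.shape) for g in grad_input]
    return hook

# Register one hook per element instead of one hook on the container.
handles = [layer.register_full_backward_hook(make_hook(i))
           for i, layer in enumerate(model)]

x = torch.randn(4, 5, requires_grad=True)
out = model(x)
out.backward(torch.ones(4, 2))

# Remove the hooks once they are no longer needed.
for h in handles:
    h.remove()

for index in sorted(captured):
    print(index, captured[index])
```

Each hook here only records shapes and returns nothing, so the backward pass itself is left unchanged.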
Yes, you should never modify any arguments given to the hook in-place. If you want to replace grad input, you can do out-of-place operations on it, and return new values from the hook.
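As a rough sketch of what returning new values can look like on a Module, the example below halves the gradient that the last Linear layer passes back. It assumes a recent PyTorch release with register_full_backward_hook (not the API discussed in this thread), and scale_grad_input is a hypothetical name chosen for illustration.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(5, 3), nn.Tanh(), nn.Linear(3, 2))

def scale_grad_input(module, grad_input, grad_output):
    # Build new tensors (out-of-place) instead of mutating grad_input.
    return tuple(None if g is None else g * 0.5 for g in grad_input)

x = torch.randn(4, 5)
target = torch.ones(4, 2)

# Reference backward pass without any hook.
model.zero_grad()
model(x).backward(target)
ref = model[0].weight.grad.clone()

# Same backward pass with the hook on the last layer.
handle = model[2].register_full_backward_hook(scale_grad_input)
model.zero_grad()
model(x).backward(target)
hooked = model[0].weight.grad.clone()
handle.remove()

# The first layer's weight gradient should be half of the reference.
print(torch.allclose(hooked, 0.5 * ref))
```

The hook changes only what flows back to the earlier layers; the last layer's own weight and bias gradients are computed from grad_output and stay as they were.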
@apaszke The example you showed seems to be for a Variable, not a Module. Is there any way I can do something similar on a Module?
What I actually want to do is modify the input gradients that are passed back to the previous modules, so that I can do adversarial training.
In Torch7, for example, the GradientReversal module multiplies gradInput by -1 (only gradInput, not the gradients used for the module’s own weight updates).
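For comparison, a gradient reversal step can also be written without hooks as a custom autograd Function. This is a minimal sketch assuming a recent PyTorch release with static-method Functions; GradReverse is a hypothetical name, not an existing PyTorch module, and it is not necessarily how the Torch7 module is implemented.

```python
import torch

class GradReverse(torch.autograd.Function):
    """Identity in the forward pass, sign flip in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # Return a view so the output is a distinct tensor from the input.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Out-of-place negation of the incoming gradient.
        return grad_output.neg()

x = torch.ones(3, requires_grad=True)
y = (GradReverse.apply(x) * 2.0).sum()
y.backward()

# Without the reversal the gradient would be +2 for every element.
print(x.grad)
```

Because the Function has no parameters, only the gradient passed to earlier layers is flipped; the weight gradients of any layers after it are unaffected.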
For a Linear module with parameters W and b, grad_i will be a tuple that includes gradient over input, W and b. So should we instead use the following function?