When registering two nn.Sigmoid() backward hooks that print saved activations from a forward pass, stored in a dict, I encounter no problems. However, as soon as I make one of the backward hooks return a new grad_input, the other backward hook raises an error when trying to access the activation dict. Where does this error stem from? And how should I access forward activations if I cannot use a dict?
Network
import torch
import torch.nn as nn

class HookNet(nn.Module):
    def __init__(self):
        super(HookNet, self).__init__()
        self.fc1 = nn.Linear(2, 2)
        self.s1 = nn.Sigmoid()
        self.fc2 = nn.Linear(2, 1)
        self.s2 = nn.Sigmoid()
        # Fixed weights and biases so the forward/backward values are reproducible
        self.fc1.weight = torch.nn.Parameter(torch.Tensor([[1, 2], [-1, 2]]))
        self.fc1.bias = torch.nn.Parameter(torch.Tensor([0]))
        self.fc2.weight = torch.nn.Parameter(torch.Tensor([[1, 2]]))
        self.fc2.bias = torch.nn.Parameter(torch.Tensor([0]))

    def forward(self, x):
        x = self.fc1(x)
        x = self.s1(x)
        x = self.fc2(x)
        x = self.s2(x)
        return x

hooknet = HookNet()
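For reference, the hard-coded weights make the forward pass easy to verify by hand. This little check (not part of the original question, just a sketch) reproduces the intermediate values that show up in the hook printouts below:

# Sanity check of the fixed weights; the intermediate values reappear
# in the hook printouts further down.
x = torch.Tensor([1, 1])
h = hooknet.fc1(x)    # [1*1 + 2*1, -1*1 + 2*1] = tensor([3., 1.])
a = hooknet.s1(h)     # sigmoid -> tensor([0.9526, 0.7311])
z = hooknet.fc2(a)    # 0.9526 + 2*0.7311 = tensor([2.4147])
y = hooknet.s2(z)     # sigmoid -> tensor([0.9179])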
Saving forward activations and printing them in the backward pass
saved_activations = {}

def forward_save_act(name, module, input, output):
    # Store this module's (input, output) tensors for use in the backward pass
    saved_activations[name] = (input[0].data, output.data)

def backward_use_act(name, module, grad_input, grad_output):
    print('___Backward pass for '+str(name)+'___')
    input, output = saved_activations[name]
    print('Saved Input: '+str(input))
    print('Saved Output: '+str(output))
    print('Grad_input needed to be overwritten: '+str(grad_input))
    new_grad_input = output/input
    grad_tuple = (new_grad_input.data*grad_output[0],)
    print('New grad input: '+str(grad_tuple))

from functools import partial

for name, m in hooknet.named_modules():
    if type(m) == nn.Sigmoid:
        m.register_forward_hook(partial(forward_save_act, name))
        m.register_backward_hook(partial(backward_use_act, name))
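As a side note, both register_forward_hook and register_backward_hook return a RemovableHandle, so a variant of the loop that keeps the handles (a small sketch, not in my original code) makes it easy to detach all hooks between experiments:

handles = []
for name, m in hooknet.named_modules():
    if type(m) == nn.Sigmoid:
        handles.append(m.register_forward_hook(partial(forward_save_act, name)))
        handles.append(m.register_backward_hook(partial(backward_use_act, name)))

# Detach all hooks again, e.g. before re-registering a modified hook:
for h in handles:
    h.remove()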
This yields good results, and during the backward pass the hooks have no trouble accessing saved_activations and printing its contents. An example:
inp = torch.Tensor([1, 1])
inp.requires_grad = True
out = hooknet(inp)
out.backward()
Output: [0.91794074]
___Backward pass for s2___
Saved Input: tensor([2.4147])
Saved Output: tensor([0.9179])
Grad_input needed to be overwritten: (tensor([0.0753]),)
New grad input: (tensor([0.3801]),)
___Backward pass for s1___
Saved Input: tensor([3., 1.])
Saved Output: tensor([0.9526, 0.7311])
Grad_input needed to be overwritten: (tensor([0.0034, 0.0296]),)
New grad input: (tensor([0.0239, 0.1101]),)
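For anyone checking the arithmetic: since out.backward() is called on a scalar, grad_output at s2 is 1, and grad_output at s1 is s2's local gradient propagated through fc2.weight. A quick sketch reproducing the printed numbers:

go_s2 = torch.tensor([1.0])                          # backward seed for a scalar output
print(0.9179 / 2.4147 * go_s2)                       # -> tensor([0.3801])
go_s1 = 0.0753 * torch.tensor([1.0, 2.0])            # 0.0753 through fc2.weight
print(torch.tensor([0.9526/3, 0.7311/1]) * go_s1)    # -> tensor([0.0239, 0.1101])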
However, if I return the calculated grad_input, I encounter problems in the backward pass for s1. The interesting thing is that the error is associated with the lookup in the dict.
def backward_use_act(name, module, grad_input, grad_output):
    print('___Backward pass for '+str(name)+'___')
    input, output = saved_activations[name]
    print('Saved Input: '+str(input))
    print('Saved Output: '+str(output))
    print('Grad_input needed to be overwritten: '+str(grad_input))
    new_grad_input = output/input
    grad_tuple = (new_grad_input.data*grad_output[0],)
    print('New grad input: '+str(grad_tuple))
    # The only change: actually return the new grad_input
    return grad_tuple

from functools import partial

for name, m in hooknet.named_modules():
    if type(m) == nn.Sigmoid:
        m.register_forward_hook(partial(forward_save_act, name))
        m.register_backward_hook(partial(backward_use_act, name))
inp = torch.Tensor([1, 1])
inp.requires_grad = True
out = hooknet(inp)
out.backward()
___Backward pass for s2___
Saved Input: tensor([2.4147])
Saved Output: tensor([0.9179])
Grad_input needed to be overwritten: (tensor([0.0753]),)
New grad input: (tensor([0.3801]),)
Traceback (most recent call last):
  File "", line 49, in <module>
    out.backward()
  File "/Users/Eigil/opt/anaconda3/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/Users/Eigil/opt/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "", line 28, in backward_use_act
    print('___Backward pass for '+str(name)+'___')
SystemError: <class 'str'> returned a result with an error set
If I hardcode the grad_tuples rather than loading them from the dict (or pass them as an additional argument through partial), I encounter no problems.
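For completeness, a minimal sketch of the partial-based workaround mentioned above. It assumes the backward hooks are registered after a forward pass, so the saved tensors already exist and can be bound as arguments; the hypothetical backward_use_act_bound is just the original hook with the dict lookup replaced by bound arguments:

def backward_use_act_bound(name, input, output, module, grad_input, grad_output):
    # Same body as before, but input/output arrive via partial, not the dict
    new_grad_input = output/input
    grad_tuple = (new_grad_input.data*grad_output[0],)
    return grad_tuple

out = hooknet(inp)  # forward pass fills saved_activations first
for name, m in hooknet.named_modules():
    if type(m) == nn.Sigmoid:
        saved_in, saved_out = saved_activations[name]
        m.register_backward_hook(
            partial(backward_use_act_bound, name, saved_in, saved_out))
out.backward()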