Can I backpropagate through forward hooks?
For example:
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 3),   # layer0
            nn.ReLU(True),
            nn.Conv2d(32, 64, 3),  # layer1
            nn.ReLU(True),
            nn.Conv2d(64, 64, 3),  # layer2
            nn.ReLU(True),
            nn.Conv2d(64, 1, 1))   # layer3
        self.other_net = nn.Conv2d(64, 1, 1)

        # Forward hook to store the output of layer1 (index 2 in the Sequential,
        # which produces the 64 channels that other_net expects)
        self.layer_outputs = {}
        self.net[2].register_forward_hook(
            save_outputs(self.layer_outputs, 'layer1'))

    def forward(self, x):
        pred = self.net(x)
        other_pred = self.other_net(self.layer_outputs['layer1'])
        return pred, other_pred

def save_outputs(output_dict, name):
    '''Closure to save the outputs in a forward hook.'''
    def hook(self, input, out):
        output_dict[name] = out
    return hook

crit = nn.MSELoss()
model = Net()
x = torch.rand(1, 3, 10, 10)  # 10x10 input so the three 3x3 convs leave a 4x4 map
pred, other_pred = model(x)
# pred is (1, 1, 4, 4); other_pred comes from layer1's 6x6 feature map
loss = crit(pred, torch.ones(1, 1, 4, 4)) + crit(other_pred, torch.ones(1, 1, 6, 6))
loss.backward()
Of course, the gradient from pred, which comes from a full pass through the network, contributes to the update. However, does the gradient from other_pred contribute as well? Does a forward hook keep the stored output attached to the autograd graph, so that layer1 is updated by both the loss from pred and the loss from other_pred?
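For reference, this is roughly the check I had in mind, assuming the model above; the names grad_pred_only and grad_both are just for illustration:

# Sanity check: the tensor saved by the hook should still carry a grad_fn,
# and layer1's weight gradient should change once the other_pred loss is added.
model = Net()
crit = nn.MSELoss()
x = torch.rand(1, 3, 10, 10)

pred, other_pred = model(x)
print(model.layer_outputs['layer1'].grad_fn)  # non-None -> still in the graph

# Gradient on layer1's weights from the pred loss alone
crit(pred, torch.ones(1, 1, 4, 4)).backward(retain_graph=True)
grad_pred_only = model.net[2].weight.grad.clone()

# Gradient when both losses are used
model.zero_grad()
(crit(pred, torch.ones(1, 1, 4, 4))
 + crit(other_pred, torch.ones(1, 1, 6, 6))).backward()
grad_both = model.net[2].weight.grad

print(torch.equal(grad_pred_only, grad_both))  # False if other_pred contributes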