Backpropagate through forward hooks?

Can I backpropagate through forward hooks?

For example:

import torch
import torch.nn as nn

class Net(nn.Module):
    """Conv stack plus an auxiliary head fed from an intermediate layer.

    A forward hook on ``self.net[3]`` (the ReLU after the ``# layer1`` conv)
    stores that module's output in ``self.layer_outputs['layer1']``.
    ``forward`` then feeds the stored tensor to ``self.other_net``. The hook
    stores the output tensor itself (not a detached copy), so a loss on
    ``other_pred`` also backpropagates into layer0 and layer1.
    """

    def __init__(self):
        super().__init__()

        self.net = nn.Sequential(
            nn.Conv2d(3, 32, 3),   # layer0
            nn.ReLU(True),
            nn.Conv2d(32, 64, 3),  # layer1
            nn.ReLU(True),
            nn.Conv2d(64, 64, 3),  # layer2
            nn.ReLU(True),
            nn.Conv2d(64, 1, 1),   # layer3
        )

        # Auxiliary head; it consumes the 64-channel layer1 activations.
        self.other_net = nn.Conv2d(64, 1, 1)

        # Forward hooks store intermediate outputs here, keyed by name.
        self.layer_outputs = {}
        # Hook net[3], the ReLU after layer1, which yields 64 channels.
        # net[1] would yield only 32 channels, which other_net cannot accept.
        self.net[3].register_forward_hook(
            save_outputs(self.layer_outputs, 'layer1'))

    def forward(self, x):
        """Return ``(pred, other_pred)`` for the input batch ``x``.

        ``self.net`` must run first: that call fires the hook which fills
        ``self.layer_outputs['layer1']`` for the auxiliary head.
        """
        pred = self.net(x)
        other_pred = self.other_net(self.layer_outputs['layer1'])
        return pred, other_pred

def save_outputs(output_dict, name):
    """Build a forward hook that records a module's output.

    The returned callable takes the ``(module, inputs, output)`` arguments
    PyTorch passes to forward hooks. On every call it stores ``output`` in
    ``output_dict[name]``, replacing whatever was stored there before. The
    tensor is kept as-is (no ``detach``), so it remains part of the graph.
    """
    def _record(module, inputs, output):
        output_dict[name] = output

    return _record

crit = nn.MSELoss()
model = Net()
# Each unpadded 3x3 conv shrinks height/width by 2, so three of them need an
# input of at least 7x7; 4x4 collapses to 0x0 at the second conv.
x = torch.rand(1, 3, 8, 8)
pred, other_pred = model(x)
# The two heads branch off at different depths and so have different spatial
# sizes; build each target from its own prediction's shape.
loss = (crit(pred, torch.ones_like(pred))
        + crit(other_pred, torch.ones_like(other_pred)))
loss.backward()

Of course, the gradient from pred, which is predicted from a full pass through the network, contributes to the update. However, does the gradient from other_pred contribute as well? Does a forward hook preserve the gradients so that layer1 is updated by both the loss from pred and other_pred?

1 Like

The gradients will propagate regardless of where you put the calculation (well, except inside `with torch.no_grad():` blocks and the like).
Personally, I would just spell out the forward instead of using hooks. It’s more readable and probably less code overall, too.

def forward(self, x):
    """Run ``self.net`` layer by layer, branching ``self.other_net`` off midway.

    Returns ``(pred, other_pred)``. Both outputs are computed from the same
    intermediate tensors in one pass, so gradients from either loss flow
    back through the shared lower layers.
    """
    # net[:4] ends with the ReLU after the 32->64 conv. This assumes
    # other_net is the question's Conv2d(64, 1, 1), so it must branch off
    # only once 64 channels exist.
    split = 4
    y = x
    for layer in self.net[:split]:
        y = layer(y)
    other_pred = self.other_net(y)
    for layer in self.net[split:]:
        y = layer(y)
    return y, other_pred

or something similar. Alternatively, you could split the net into lower and upper `nn.Sequential` parts and call those instead of using the for loops:

def forward(self, x):
    """Compute the shared lower features once and feed both heads from them.

    Returns ``(pred, other_pred)``: the output of ``self.net_upper`` and of
    ``self.other_net``, both applied to the result of ``self.net_lower``.
    ``other_net`` runs before ``net_upper``.
    """
    shared = self.net_lower(x)
    other_pred = self.other_net(shared)
    return self.net_upper(shared), other_pred

Best regards

Thomas

1 Like