Using forward hooks vs Modifying the forward function

Hi,

When I need to modify the input to and the output from the forward function of a layer, I can think of two ways:

  1. Register forward pre-hooks and forward hooks
  2. Modify the forward function of the layer itself.

Below is a simple illustration of each.

In my opinion the two are exactly equivalent. Anything that can be done using 1 can be done using 2 and vice versa.
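To make the equivalence claim concrete, here is a minimal sketch. It goes beyond printing and actually changes the input and the output, once with hooks and once with a forward wrapper. It assumes a small CPU `nn.Linear` in place of the `Conv2d` from the demo, and the doubling/shifting transforms are arbitrary illustrations. A forward pre-hook that returns a value replaces the layer's positional inputs, and a forward hook that returns a non-`None` value replaces its output.

```python
import torch
from torch import nn

torch.manual_seed(0)
inp = torch.rand(2, 4)

# Layer A: modify input and output through hooks
layer_a = nn.Linear(4, 3)

def scale_input(module, args):
    # args is the tuple of positional inputs; returning a tuple replaces it
    return (args[0] * 2,)

def shift_output(module, args, output):
    # returning a non-None value replaces the layer's output
    return output + 1

layer_a.register_forward_pre_hook(scale_input)
layer_a.register_forward_hook(shift_output)

# Layer B: same weights, same modifications done by wrapping forward
layer_b = nn.Linear(4, 3)
layer_b.load_state_dict(layer_a.state_dict())
original_forward = layer_b.forward  # bound method, `self` already attached

def wrapped_forward(x):
    return original_forward(x * 2) + 1

layer_b.forward = wrapped_forward  # instance attribute shadows the class method

out_a = layer_a(inp)
out_b = layer_b(inp)
print(torch.allclose(out_a, out_b))  # True
```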

My question: Is this understanding correct? Is there any reason to prefer one over the other, either for performance or as general good practice?

Demo:

  1. Using hooks
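import torch
from torch import nn

################### Using forward hooks #########################

## Setup (falls back to CPU if no GPU is available)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2).to(device)
inp = torch.rand((1, 3, 299, 299)).to(device)

## pre-forward hook: receives the module and the tuple of positional inputs
def print_pre(module, input):
    print("I am about to start the forward run!")

## post-forward hook: receives the module, the inputs, and the output
def print_post(module, input, output):
    print("I am done with the forward")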

## Register the hooks
conv1.register_forward_pre_hook(print_pre)
conv1.register_forward_hook(print_post)

## forward run
out1 = conv1(inp)
  2. Using a wrapper
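############### Using forward Wrapper #################
import torch
from torch import nn

## Setup (falls back to CPU if no GPU is available)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
conv2 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2).to(device)
inp = torch.rand((1, 3, 299, 299)).to(device)

## save original forward function (a bound method, so `self` is already attached)
original_forward = conv2.forward

## define a wrapper around the original forward function
## a function stored as an instance attribute is not bound, so the wrapper
## receives only the inputs, not the module
def modified_forward(*input):
    print("I am about to start the forward run!")
    out = original_forward(*input)
    print("I am done with the forward")
    return out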

## set wrapper as the new forward 
conv2.forward = modified_forward

## forward run
out2 = conv2(inp)
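One practical difference between the two demos is how easily each can be undone. Each `register_*` call returns a handle (`torch.utils.hooks.RemovableHandle`) whose `remove()` detaches that hook. Undoing the wrapper instead means restoring the saved `original_forward` by hand. A minimal sketch, assuming a small CPU `nn.Linear` and a hypothetical `calls` list used only to record which hooks fired:

```python
import torch
from torch import nn

layer = nn.Linear(4, 3)
calls = []  # hypothetical log of which hooks ran

def log_pre(module, args):
    calls.append("pre")

def log_post(module, args, output):
    calls.append("post")

pre_handle = layer.register_forward_pre_hook(log_pre)
post_handle = layer.register_forward_hook(log_post)

layer(torch.rand(1, 4))   # both hooks fire
pre_handle.remove()
post_handle.remove()
layer(torch.rand(1, 4))   # no hooks fire any more

print(calls)  # ['pre', 'post']
```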

Thank you!

Yes, you are correct that both approaches should yield the same results.
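The results match because of the order in which a module call runs things. Pre-hooks can replace the inputs, then `forward` runs, then forward hooks can replace the output. The toy class below is a simplified, hypothetical model of that dispatch in plain Python, not PyTorch's actual implementation. It ignores backward hooks, keyword arguments, global hooks, and everything else `nn.Module.__call__` handles.

```python
class ToyModule:
    """Hypothetical, simplified stand-in for nn.Module's hook dispatch."""

    def __init__(self):
        self._pre_hooks = []
        self._post_hooks = []

    def register_forward_pre_hook(self, hook):
        self._pre_hooks.append(hook)

    def register_forward_hook(self, hook):
        self._post_hooks.append(hook)

    def forward(self, x):
        return x * 10

    def __call__(self, *args):
        # 1. pre-hooks may replace the positional inputs
        for hook in self._pre_hooks:
            result = hook(self, args)
            if result is not None:
                args = result if isinstance(result, tuple) else (result,)
        # 2. forward (possibly overridden on the instance) runs on the final inputs
        output = self.forward(*args)
        # 3. forward hooks see the final inputs and may replace the output
        for hook in self._post_hooks:
            result = hook(self, args, output)
            if result is not None:
                output = result
        return output


# hooks: (2 + 1) * 10 - 1
m = ToyModule()
m.register_forward_pre_hook(lambda mod, args: args[0] + 1)
m.register_forward_hook(lambda mod, args, out: out - 1)
print(m(2))  # 29

# the same transformation by wrapping forward instead
w = ToyModule()
orig = w.forward
w.forward = lambda x: orig(x + 1) - 1
print(w(2))  # 29
```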
Using hooks could be easier if you want to reuse a predefined model and don’t want to override its forward method (e.g. torchvision.models.resnet50()), but this also depends on your personal coding style. For example, I would prefer to override the forward method if I really want to manipulate the forward pass, and use hooks only for debugging or quick testing.
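As an illustration of the "predefined model" case, a forward hook can capture an intermediate activation without touching the model's source. This sketch assumes a small `nn.Sequential` as a stand-in for a pretrained model. With `torchvision.models.resnet50()` the same call could target a submodule such as `model.layer4`. The `save_activation` helper and `features` dict are hypothetical names for this example.

```python
import torch
from torch import nn

# stand-in for a predefined model whose forward we do not want to edit
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
)

features = {}  # hypothetical store for captured activations

def save_activation(name):
    # build a forward hook that stores the module's output under `name`
    def hook(module, args, output):
        features[name] = output.detach()
    return hook

handle = model[1].register_forward_hook(save_activation("relu"))
out = model(torch.rand(1, 3, 32, 32))
handle.remove()  # detach the hook once the activation is captured

print(features["relu"].shape)  # torch.Size([1, 8, 32, 32])
print(out.shape)               # torch.Size([1, 16, 32, 32])
```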

Additionally, I think hooks are not supported in scripted models (feature request tracked here), so if you would like to script the model in the future, I would also override the forward method.
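A minimal sketch of that route, assuming the conv layer from the demo is wrapped in a small custom module whose `forward` does the extra work. The class name `LoggedConv` is hypothetical. The conv is held as a submodule rather than subclassed, so `forward` calls `self.conv(x)` instead of `super().forward(x)`, a pattern TorchScript compiles without trouble.

```python
import torch
from torch import nn

class LoggedConv(nn.Module):
    """Hypothetical module: the demo's logging lives directly in forward."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        print("I am about to start the forward run!")
        out = self.conv(x)
        print("I am done with the forward")
        return out

scripted = torch.jit.script(LoggedConv())
out = scripted(torch.rand(1, 3, 299, 299))
print(out.shape)  # torch.Size([1, 64, 74, 74])
```

Spatial size check: (299 + 2·2 − 11) / 4 + 1 = 74.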


Thank you, @ptrblck! That makes sense.