How to find the order of calls to different forward methods when calling model(input)?

I would like to dynamically replace the forward function of some layers in a neural network with some other forward functions.

Currently, I achieve this by registering forward hooks on the target layers, where the hook function calls the desired forward and returns its result in place of the original output. However, this implementation has two shortcomings. First, both the original and the new forward functions are executed, which wastes computation. Second, the hooks themselves seem to add significant overhead.
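
For reference, here is the basic pattern I am using (a rough sketch; new_forward is a placeholder for one of my replacement functions):

import torch
import torch.nn as nn

def new_forward(x):
    # placeholder for the actual replacement logic
    return x * 2

def replacing_hook(module, input, output):
    # the original forward has already run by the time this fires;
    # returning a value from a forward hook replaces the module's output
    return new_forward(input[0])

layer = nn.ReLU()
layer.register_forward_hook(replacing_hook)
out = layer(torch.randn(4))  # out is new_forward's result, not ReLU's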

I was thinking that an alternative solution would be to find all the forward functions that are called, replace some of them, and call the new high-level forward function. However, I am not sure how to find the order in which all forward functions are called.

Can you please let me know if there is an easy way of finding all the forward functions that are called, in the correct order?
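
For what it's worth, one could trace a single forward pass with pre-hooks that just record each module's name (a rough sketch below), but I'm hoping there is a more direct way:

import torch
from torchvision import models

model = models.resnet50()

call_order = []
handles = []
for name, module in model.named_modules():
    if len(list(module.children())) == 0:  # leaf modules only
        handles.append(module.register_forward_pre_hook(
            lambda m, inp, name=name: call_order.append(name)))

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))
for h in handles:
    h.remove()

print(call_order)  # e.g. ['conv1', 'bn1', 'relu', 'maxpool', ...]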

How large is the slowdown using hooks?
Would creating custom modules by deriving from the corresponding base class and reimplementing the forward work?
You could monkey-patch the forward method if you want to change it for all instances of the module, but that's of course not the cleanest way.
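
A minimal sketch of the monkey-patching idea, assuming you want to swap the forward of every nn.ReLU (new_forward is just a placeholder):

import types

import torch
import torch.nn as nn

def new_forward(self, x):
    # placeholder for the real replacement logic
    return torch.clamp(x, min=0., max=6.)

# change the forward for ALL instances of the module class
nn.ReLU.forward = new_forward

# or patch a single instance instead:
layer = nn.ReLU()
layer.forward = types.MethodType(new_forward, layer)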


I don’t have exact numbers yet, but on a simple MLP trained on MNIST, a batch size of 10,000 gives a runtime of 25 seconds, while a batch size of 256 gives a runtime of 9 minutes. I guess this is mostly due to the overhead of the hooks.

That might work, but sometimes I need to replace multiple consecutive layers with a single layer. Currently, in those cases, the forward hook of the first layer in the group stores its inputs, the new layer processes those inputs, and the forward hook of the last layer in the group replaces the original output with the new layer’s output.
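
Concretely, the pattern looks roughly like this (a sketch on a toy model; new_layer stands in for my actual fused layer):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
first_layer, last_layer = model[0], model[2]  # the group to replace
new_layer = nn.Linear(8, 4)                   # stands in for the fused layer

stash = {}

def store_input_hook(module, input, output):
    # hook on the first layer of the group: remember the group's input
    stash['x'] = input[0]

def replace_output_hook(module, input, output):
    # hook on the last layer of the group: discard the original output
    # and substitute the new layer's result
    return new_layer(stash['x'])

first_layer.register_forward_hook(store_input_hook)
last_layer.register_forward_hook(replace_output_hook)

out = model(torch.randn(2, 8))  # out comes from new_layer, shape (2, 4)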

Ideally, I would like to eliminate all the calls to the old layers to save computation and get a speedup.

The timing sounds weird.
Are you sure this slowdown comes from the hooks?
I would expect to see some overhead, but yours seems really large.

I’m seeing a ~4% slowdown using this dummy code:

import time

import torch
from torchvision import models

model = models.resnet50()
model.cuda()

x = torch.randn(32, 3, 224, 224).cuda()

# warmup
for _ in range(50):
    out = model(x)

nb_iters = 100
torch.cuda.synchronize()
t0 = time.time()

for _ in range(nb_iters):
    out = model(x)

torch.cuda.synchronize()
t1 = time.time()

print('{}s/iter'.format((t1 - t0) / nb_iters))


# Use hooks
def my_hook(m, input, output):
    # returning a value from a forward hook replaces the module's output
    output = output * 100
    return output

for child in model.children():
    child.register_forward_hook(my_hook)


# warmup
for _ in range(50):
    out = model(x)

nb_iters = 100
torch.cuda.synchronize()
t0 = time.time()

for _ in range(nb_iters):
    out = model(x)

torch.cuda.synchronize()
t1 = time.time()

print('{}s/iter'.format((t1 - t0) / nb_iters))

Thank you for sharing this snippet of code. It is really helpful.

I profiled the code, and it seems the issue is not the hooks but the replacement functions themselves. Regardless of whether the batch size is 256 or 10,000, the replacement for the first layer takes 12 seconds on the GPU and the one for the second layer takes about 6 seconds. As a result, the configuration with a large number of small batches takes a lot longer.

Most of my replacement functions consist of a single, very long line of bit-wise operations on bool tensors, e.g.
out[:, 0] = (x[:, 0] & ~x[:, 1]) | (x[:, 124] & (x[:, 118] & ~x[:, 254]))...

In both cases, the GPU utilization is around 20%.
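
For reference, this is roughly how I'm isolating the cost of a single replacement function on the GPU (a sketch; my_replacement_forward is a stand-in for one of my generated functions):

import time

import torch

def time_fn(fn, *args, n_iters=100):
    # warmup, then time with explicit synchronization so the
    # asynchronous CUDA launches are included in the measurement
    for _ in range(10):
        fn(*args)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(n_iters):
        fn(*args)
    torch.cuda.synchronize()
    return (time.time() - t0) / n_iters

def my_replacement_forward(x):
    # stand-in: a short chain of bit-wise ops on a bool tensor
    out = torch.empty(x.size(0), 1, dtype=torch.bool, device=x.device)
    out[:, 0] = (x[:, 0] & ~x[:, 1]) | (x[:, 124] & (x[:, 118] & ~x[:, 254]))
    return out

x = torch.randint(0, 2, (10000, 256), dtype=torch.bool, device='cuda')
print('{}s/iter'.format(time_fn(my_replacement_forward, x)))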