Inference with only certain layers of a model

Hi all,

I have an nn.Module model with pretrained weights, and I would like to run inference with only a subset of its layers. For example, given a model with 5 layers [layer_0, layer_1, layer_2, layer_3, layer_4], I would like to run inference from layer_1 through layer_3 only.

The direct but tedious way is to define a new forward() function, but then I would have to go into the source code and redefine it for every new model I want to use. Are there more efficient alternatives?
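For concreteness, here is a rough sketch of that direct route; it assumes (hypothetically) that the model exposes its layers as an iterable attribute called blocks:

import torch

# Minimal sketch of the "redefine forward()" route. `blocks`, `start`, and `end`
# are illustrative; real models expose their layers under different names.
@torch.no_grad()
def forward_subset(model, x, start=1, end=4):
    # x must already have the shape that layer_{start} expects.
    # Runs only blocks[start:end], e.g. layer_1 through layer_3 for start=1, end=4.
    for block in list(model.blocks)[start:end]:
        x = block(x)
    return x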

I have tried two approaches so far, but each fails in some way:

  1. torch.nn.modules.module.register_module_forward_hook: a forward hook lets me grab intermediate outputs, but it still seems to require running the forward pass from the first layer of the model, i.e. the input is always expected to enter at layer_0 (see the sketch after this list).
  2. torch.fx: the symbolic tracer is not compatible with many operations that are not native to torch, for example einops and **kwargs.
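For reference, here is roughly what attempt (1) looks like; the module index and full_input are illustrative:

# A forward hook captures an intermediate output, but the forward pass
# still has to start at layer_0.
captured = {}

def save_output(module, args, output):
    captured['block_3'] = output

handle = model.blocks[3].register_forward_hook(save_output)
with torch.no_grad():
    _ = model(full_input)   # the input still has to enter at layer_0
handle.remove()
intermediate = captured['block_3']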

Thanks!

  1. torch.nn.modules.module.register_module_forward_hook: a forward hook lets me grab intermediate outputs, but it still seems to require running the forward pass from the first layer of the model, i.e. the input is always expected to enter at layer_0.

This sounds like a reasonable approach. Is there a sub-module you can attach the hook to?

Yes, I do have sub-modules, and I can get their outputs using a forward hook. Is it possible to use a forward hook so that forward inference starts from layer_1 instead of from layer_0?

You can pass in dummy inputs to layer_0 and create those tensors with device='meta' so that no real compute is actually done. Then, in a forward pre-hook on layer_1, swap out its inputs with the actual inputs.

Thanks for the response! I followed your idea and implemented it. Here is the pseudocode:

# Shared context the hook reads the real input from
context = {}

# Define the pre-hook: replace whatever input the module receives
# with the tensor stored in context['replacement_input']
def create_forward_pre_hook(context):
    def forward_pre_hook(module, input):
        return (context['replacement_input'],)  # returned as a tuple of positional args
    return forward_pre_hook

# Register the hook on block_0
hook_handle = model.blocks[0].register_forward_pre_hook(create_forward_pre_hook(context))

# Inference
for step, data in enumerate(data_loader):
    context['replacement_input'] = data.to('cuda')  # the real input the hook swaps in
    dummy_input = torch.ones_like(data, device='cuda')
    output = model(dummy_input)

The output is different when I register the hook than when I don't, so I believe the hook is indeed replacing the input to block_0.

However, this only works with dummy_input = torch.ones_like(data, device='cuda') and not with dummy_input = torch.ones_like(data, device='meta'). Since my model is on 'cuda' and dummy_input is on 'meta', torch complains that they are not on the same device. Perhaps I didn't fully understand your idea about the 'meta' device. Could you please elaborate on how to use the 'meta' device to avoid doing real compute in the layers before block_0, for efficiency's sake?

Does first moving that first layer to the meta device with .to() work?
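Something like this, building on your pseudocode; it is only a sketch and assumes the pre-hook is registered on model.blocks[1], so that block_0 is the only layer that runs before it (most common modules have meta implementations, so the dummy pass through block_0 should cost nothing):

model.blocks[0].to('meta')  # the skipped layer: its parameters become meta tensors

for step, data in enumerate(data_loader):
    context['replacement_input'] = data.to('cuda')
    dummy_input = torch.ones_like(data, device='meta')  # meta dummy meets meta weights: no device mismatch, no real compute
    output = model(dummy_input)                          # real compute only starts at blocks[1]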

Do you mean moving the first layer of my model to the meta device? Maybe I need to move all the layers that run before the layer my pre-hook is attached to onto the meta device? Is there a way to identify all the layers used in the forward pass before a particular layer?

A dumb way of doing this is to have the entire model on meta initially and only move everything off meta after hitting the pre-hook.
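Roughly like this, reusing the context dict from your pseudocode and assuming the pretrained weights are available separately as a state_dict (meta parameters hold no data, so the weights have to be reloaded after materializing); the checkpoint path and block index are placeholders:

state_dict = torch.load('checkpoint.pt', map_location='cpu')  # hypothetical checkpoint
model.to('meta')                                  # whole model on meta: no parameter memory, no real compute

def materializing_pre_hook(module, args):
    if next(model.parameters()).is_meta:          # only materialize once
        model.to_empty(device='cuda')             # allocate real (uninitialized) parameters
        model.load_state_dict(state_dict)         # fill them with the pretrained weights
    return (context['replacement_input'],)        # swap in the real input, as before

model.blocks[1].register_forward_pre_hook(materializing_pre_hook)

data = next(iter(data_loader))
context['replacement_input'] = data.to('cuda')
output = model(torch.ones_like(data, device='meta'))
# Caveat: after this first pass the whole model (including block_0) is on cuda, so later
# batches need either a cuda dummy (block_0 then does real work again) or the skipped
# blocks moved back to meta.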

Is there a way to identify all the layers used in the forward pass before a particular layer?

A smarter way may involve (1) setting a global flag upon hitting the pre-hook, (2) attaching a hook that interposes on all layers, and (3) sending tensors through and seeing which modules we hit before the flag is set.
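A sketch of that idea: a pre-hook on every leaf module records execution order until the target block's own pre-hook flips the flag (names are illustrative):

state = {'reached_target': False}
modules_before_target = []

def make_recorder(name, is_target):
    def pre_hook(module, args):
        if is_target:
            state['reached_target'] = True
        elif not state['reached_target']:
            modules_before_target.append(name)
    return pre_hook

target = model.blocks[1]  # the block we want inference to start from
handles = []
for name, module in model.named_modules():
    is_leaf = len(list(module.children())) == 0
    if module is target or is_leaf:
        handles.append(module.register_forward_pre_hook(make_recorder(name, module is target)))

with torch.no_grad():
    _ = model(torch.ones_like(data, device='cuda'))  # one throwaway tracing pass with a real input

for h in handles:
    h.remove()
# modules_before_target now lists the leaf modules that ran before blocks[1];
# those are the candidates to move to the meta device.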

Thanks for the help and suggestions! I will consider using the meta device later in development.