Stop model evaluation with forward hook

Let’s say I have a model that I only want to run until the 8-th layer. I register a forward hook on that layer to save the output, and all works well. The problem is that the network will still run until completion, which makes things slower than necessary.

Is there a way to stop it early? Or should I raise an exception in the forward hook and catch it outside, to get out of the network's forward function?


Do you have any control over the architecture of the network? In the python script for that network I imagine you could just delete the lines after the layer whose output you want.

Sure, but I want to do this automatically, without manually deleting lines of code every time :slight_smile:


You could add an argument to the forward method of your model.

def forward(self, input, num_layers):
    # automatically use only num_layers
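A minimal sketch of that suggestion (the model, layer sizes, and argument name are made up for illustration):

```python
import torch
import torch.nn as nn

class TruncatableModel(nn.Module):
    # Hypothetical model whose forward accepts a num_layers argument
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(10)])

    def forward(self, x, num_layers=None):
        # Run only the first num_layers layers (all of them if None)
        for layer in self.layers[:num_layers]:
            x = layer(x)
        return x

model = TruncatableModel()
out = model(torch.randn(2, 4), num_layers=3)  # stops after the 3rd layer
```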

Yes, sure. But then I would need to manually modify the forward method of every class I am interested in. What I want is to have a layer of abstraction; I don’t care what module it is or how it’s defined, I want to stop execution after a certain forward hook is triggered.


One way to do this is to subclass nn.Sequential and change the forward method so that you can use a number to determine how many layers deep it goes. Then, you can build your models using your modified version of nn.Sequential to do this.
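A sketch of that idea; the class name and the `depth` argument are assumptions, not an existing API:

```python
import torch
import torch.nn as nn

class TruncatedSequential(nn.Sequential):
    # nn.Sequential whose forward can stop after `depth` children
    def forward(self, x, depth=None):
        for i, module in enumerate(self):
            if depth is not None and i >= depth:
                break
            x = module(x)
        return x

net = TruncatedSequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
partial = net(torch.randn(1, 4), depth=2)  # output of the ReLU
full = net(torch.randn(1, 4))              # full forward pass
```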


Again, thanks for the suggestion but this is not what I am looking for. I want to take in input a model, defined by someone else in some way that I don’t care about, and stop execution at the n-th module, as returned by model.modules(). I don’t want to modify the model definition, or do anything that would imply having knowledge of how the model is defined or how it works.

In short I want it to be completely automated and work on any possible model without any kind of modification on my part :slight_smile:


You need to define more specifically what you want. When you say a “module”, do you mean things in nn.* like linear layers nn.Linear, conv layers nn.Conv*d, upsampling nn.Upsample, nonlinearities, etc.? Or do you mean PyTorch Functions, which are the real fundamental ops of the computation graph, e.g. view, thnn_convolution, svd, transpose, matmul, etc.? Each of the former class is usually implemented with one or more Functions from the latter class, so you need to be clear what your notion of the n-th “module” means. If you are referring to the former class, which is probably more reasonable because they are higher-level and more intuitive, there is another difficulty: people can write models using lower-level Functions, e.g. a transpose + addmm instead of nn.Linear, or use functions, such as linalg ops including trtrs and potrs, that have no corresponding higher-level nn.* module.

As I mentioned before, I will consider a “module” everything that is returned by model.modules(). They are also the only thing you can attach forward hooks to.

Some people may decide to write their models in a different way, such that model.modules() will be empty or missing pieces, as you point out. But since there is no way around it, I will accept this limitation.

Essentially I want this:

for idx, md in model.named_modules(): 
    if idx == idx_to_stop_execution_to:
        md.register_forward_hook(my_function)
_ = model(inputs)

and my_function is defined in such a way that it saves the output of that module and then stops the forward pass of model to save time. I tried raising an exception in my_function, but the outputs were not being correctly saved.

*** EDIT ***

It turns out, the error was that I forgot to de-register the forward hooks. The code below works as intended.

for idx, md in model.named_modules(): 
    if idx == idx_to_stop_execution_to:
        handle = md.register_forward_hook(my_function)

try:
    _ = model(inputs)
except CustomException:
    pass
finally:
    handle.remove()  # de-register the forward hook
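For completeness, here is a self-contained sketch of the whole pattern; CustomException, my_function, and the toy model are illustrative names, not part of any API:

```python
import torch
import torch.nn as nn

class CustomException(Exception):
    # Used only to abort the forward pass early
    pass

saved = {}

def my_function(module, inp, output):
    # Forward hook: save the output, then abort the rest of the forward pass
    saved['output'] = output
    raise CustomException

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
idx_to_stop_execution_to = '1'  # the ReLU, as named by model.named_modules()

handle = None
for idx, md in model.named_modules():
    if idx == idx_to_stop_execution_to:
        handle = md.register_forward_hook(my_function)

try:
    _ = model(torch.randn(3, 4))
except CustomException:
    pass
finally:
    handle.remove()  # de-register so later calls run normally
```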

This is still ill-defined. What if a module contains Sequential(linear1, linear2)? All three will show up in model.modules(), but one of them contains the other two. What do you mean by n-th module?

This won’t work for a bunch of models, e.g.

def forward(self, x):
  for _ in range(100):
    x = self.fc(x)
  return x
def forward(self, x):
  if x.data[0] < 1:
    return self.fc1(x)
  return self.fc2(x)
def forward(self, x):
  z = self.subnet(Variable(torch.randn(1, 2)))
  return z + x

Nothing guarantees that named_modules or modules returns modules in a particular order. It doesn’t even work for sequential modules.

If you want to limit yourself to nn.Sequential, fine; there are a number of ways to do this. But if you want to do it for general modules, I don’t see how any of the proposals you made could work. You should think about what exactly a “module” is, and, with dynamic Python flow control, how to get the n-th actually executed “module”.


You make some very good points; especially about sequential and the possibility of each module showing up twice.

Note though that

1 - My main question was “stop execution when the forward hook is triggered”. So essentially: wherever you put a forward hook, stop execution there. The idea of having it stop at the “n-th” module was just for my convenience, so that I don’t have to inspect any piece of code.

2 - While there are significant pitfalls to be aware of (and I thank you for pointing them out!), these are left to the “user”. You give as input a model and an index n; you get as output the result of registering a forward hook on the “n-th” module returned by model.modules(). It is up to the user to identify which index is correct, whether a layer shows up twice, or whether it just doesn’t make sense (as in your last examples). This is just a utility that is useful in a large enough number of cases; where it does not make sense, it does not make sense. The only thing required for this to work is that model.modules() is deterministic, i.e. given the same model definition, the modules are returned in the same order. If this holds, the user can pick the correct n and the method will work (where applicable).

What do you think? Thanks for the answer anyhow :slight_smile:

Given the limitations, it might not be very helpful in many cases. But a way to do this may be registering a hook on every submodule, where the hook increments a global counter and, if the counter reaches N, saves the result at a global ptr and throws an exception. It’s very very hacky though :smiley:
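A sketch of that hack; all names are made up, and here hooks are registered on leaf modules only so that containers like nn.Sequential don’t fire an extra time and skew the count:

```python
import torch
import torch.nn as nn

class StopExecution(Exception):
    pass

def run_until_nth(model, inputs, n):
    """Hook every leaf submodule, count calls, and bail out at the
    n-th actually executed module (0-indexed)."""
    state = {'count': 0, 'output': None}
    handles = []

    def hook(module, inp, output):
        if state['count'] == n:
            state['output'] = output
            raise StopExecution
        state['count'] += 1

    # leaf modules only, so containers don't double-count
    for m in model.modules():
        if len(list(m.children())) == 0:
            handles.append(m.register_forward_hook(hook))
    try:
        model(inputs)
    except StopExecution:
        pass
    finally:
        for h in handles:
            h.remove()
    return state['output']

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
out = run_until_nth(model, torch.randn(2, 4), 1)  # output of the ReLU
```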


@SimonW I am implementing ManifoldMixup, which is similar to the Mixup regularization technique except that it works at the layer level instead of the input level. I need similar functionality to implement this:

  1. Select a random index and apply forward hook to that layer
  2. Forward pass using data input x_0 and record output at hooked layer
  3. Use this output along with a new input x_1 by adding a new hook at the same layer to do the mixup operation

It would be nice if I could stop the model’s forward pass at the hook in the first step. Is there a better way to do this now, since it has been over a year since this thread started?
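The three steps above can be sketched roughly like this, combining them with the exception trick from earlier in the thread. The toy model, names, and the fixed lam value are all illustrative; note that a forward hook that returns a value replaces the module’s output:

```python
import torch
import torch.nn as nn

class StopForward(Exception):
    pass

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
layers = [m for m in model.modules() if len(list(m.children())) == 0]
k = torch.randint(len(layers), (1,)).item()   # 1. pick a random layer
layer = layers[k]

saved = {}

def record(module, inp, output):
    saved['h0'] = output.detach()
    raise StopForward   # stop the rest of the x_0 pass

h = layer.register_forward_hook(record)
try:
    model(torch.randn(2, 4))   # 2. run x_0 only up to the hooked layer
except StopForward:
    pass
h.remove()

lam = 0.7  # mixing coefficient (normally drawn from a Beta distribution)

def mix(module, inp, output):
    # 3. returning a value from a forward hook replaces the layer's output
    return lam * saved['h0'] + (1 - lam) * output

h = layer.register_forward_hook(mix)
out = model(torch.randn(2, 4))   # full pass for x_1 with mixed hidden state
h.remove()
```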