Early Exiting The Network

Hi, I have a small question. I want to exit my network early, not at the end. Let's say I'm fairly certain that my result is correct after 3 layers, and I don't want to run the other 2 layers. How can I do it?

I have the condition, and I'm counting the number of times it should exit, but I'm puzzled about how to actually exit the graph. Let's say I have a batch of 32 and 10 of the samples are good enough after 3 layers: how can I stop processing them (in PyTorch, of course)?

Just to make things clear, I'm not asking about the theory; I'm asking how, in PyTorch, to finish processing those samples and move on to the next data in line.

Thanks.

Hi,

I guess you have two possibilities here:

  • When a particular element is finished, remove it from the batch and continue processing only the remaining ones. At the end, recreate the output with all the elements: for example, put each output at the right place in a list (corresponding to the sample's position in the batch) and then stack that list along the batch dimension.
  • Process all the inputs all the way to the end of the net, keep all the intermediate results, and then compute the final output with some masking: out = out_it1 * has_stopped_at_it1 + out_it2 * has_stopped_at_it2 + ..., where has_stopped_at_it1 is a Tensor with one entry per element in the batch, equal to 1 if that sample stopped at iteration 1 and 0 otherwise (reshaped, e.g. with unsqueeze, so that it broadcasts against the outputs). The sum of all these tensors should be a tensor full of 1s, since each element stops at exactly one point.
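As a rough sketch of the two options (not code from the thread), assume a hypothetical `MultiExitNet` built from `nn.ModuleList` stages, each followed by its own exit head, and a hypothetical `should_exit` criterion that compares the top-class softmax probability with a `threshold`. `forward_shrinking` follows the first option and `forward_masked` the second; in both, the last head acts as the exit for any sample still running, so the two should return the same outputs up to floating-point noise.

```python
import torch
import torch.nn as nn


class MultiExitNet(nn.Module):
    """Hypothetical model: a stack of stages, each followed by its own exit head."""

    def __init__(self, in_dim=16, hidden=32, num_classes=10, num_stages=5):
        super().__init__()
        dims = [in_dim] + [hidden] * num_stages
        self.stages = nn.ModuleList(
            nn.Sequential(nn.Linear(dims[i], dims[i + 1]), nn.ReLU())
            for i in range(num_stages)
        )
        self.heads = nn.ModuleList(
            nn.Linear(hidden, num_classes) for _ in range(num_stages)
        )


def should_exit(logits, threshold):
    # Hypothetical exit criterion: top-class softmax probability >= threshold.
    return logits.softmax(dim=1).max(dim=1).values >= threshold


@torch.no_grad()
def forward_shrinking(net, x, threshold):
    """Option 1: remove finished samples from the batch, then rebuild the output."""
    out = [None] * x.size(0)        # one slot per position in the original batch
    idx = torch.arange(x.size(0))   # original positions of the samples still running
    h = x
    last = len(net.stages) - 1
    for i, (stage, head) in enumerate(zip(net.stages, net.heads)):
        h = stage(h)
        logits = head(h)
        if i == last:               # every remaining sample stops at the last exit
            done = torch.ones(h.size(0), dtype=torch.bool)
        else:
            done = should_exit(logits, threshold)
        for j in done.nonzero().flatten().tolist():
            out[int(idx[j])] = logits[j]  # put the output back at the sample's position
        h, idx = h[~done], idx[~done]     # keep only the unfinished samples
        if idx.numel() == 0:              # nothing left to process
            break
    return torch.stack(out)               # stack along the batch dimension


@torch.no_grad()
def forward_masked(net, x, threshold):
    """Option 2: run the whole batch through every stage and select outputs with masks."""
    h = x
    out = 0.0
    stopped = torch.zeros(x.size(0), dtype=torch.bool)  # exited at an earlier stage?
    last = len(net.stages) - 1
    for i, (stage, head) in enumerate(zip(net.stages, net.heads)):
        h = stage(h)
        logits = head(h)
        exits = torch.ones_like(stopped) if i == last else should_exit(logits, threshold)
        has_stopped_here = exits & ~stopped  # True only at the first stage that qualifies
        # out = out_it1 * has_stopped_at_it1 + out_it2 * has_stopped_at_it2 + ...
        out = out + logits * has_stopped_here.unsqueeze(1).float()
        stopped |= has_stopped_here
    return out
```

For example, `forward_shrinking(net, torch.randn(32, 16), 0.9)` returns a `(32, 10)` tensor whose rows come from the exit where each sample stopped. Note that `forward_masked` still pays for every stage on every sample; only `forward_shrinking` actually skips computation, and only for the samples it has already removed.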

The second one might be faster for most use cases. The first may win out if you have many iterations and most of your samples actually stop at a very early stage.


Thanks for your answer; you definitely helped me with part of the solution.
Now for the second part, which I think is the easier one.

If I want to exit the network in the middle, what should I do?
For simplicity, let's say batch=1, and my data has just passed the second layer of the network. Now I want to check some condition: if it holds, I want to stop the processing; otherwise, I want to continue.

Now, the simplest thing I thought of is to break my network into pieces and put an if condition between them, but that looks like a stupid idea.
Right now I'm checking the condition only at the end of the network, which means I'm running till the end, saving states, and then seeing where I should have exited, but I'm not really exiting.

Am I clear enough? I know I'm having a hard time explaining this. Actually, it's a similar problem to the network "BranchyNet", if anyone knows it.

Hi,

I don't think you can do anything other than checking after every stage and breaking when you're done. You have to run this for loop anyway, so adding an if statement inside it is not a big deal.
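A minimal sketch of that loop for batch=1, assuming the network is stored as hypothetical `stages` and `heads` lists (one exit head per stage) and using the top-class softmax probability as a stand-in for the real exit condition. Because `return` leaves the loop as soon as the condition holds, the remaining stages are never computed for that sample.

```python
import torch
import torch.nn as nn

# Hypothetical 5-stage network with an exit head after every stage.
stages = nn.ModuleList(nn.Sequential(nn.Linear(16, 16), nn.ReLU()) for _ in range(5))
heads = nn.ModuleList(nn.Linear(16, 10) for _ in range(5))


@torch.no_grad()
def early_exit_forward(x, threshold=0.9):
    """Run stages one by one and stop as soon as an exit is confident enough."""
    assert x.size(0) == 1, "this sketch assumes a batch of one sample"
    for i, (stage, head) in enumerate(zip(stages, heads)):
        x = stage(x)
        logits = head(x)
        confidence = logits.softmax(dim=1).max().item()  # hypothetical exit criterion
        if confidence >= threshold:
            return logits, i   # exit here: later stages are never executed
    return logits, i           # no early exit: output of the last head
```

The returned stage index makes it easy to count how often each exit is taken, which is what the original post was already tracking.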


Got it, thanks for your support