Unrolling the model graph in a static fashion

I’m using PyTorch on TPUs, and want to implement an early exit so my layers can stop execution early.

Say, for simplicity’s sake, I have a single block - myblock() - that I want to forward through multiple times:

iters = randrange(1, 10)
input : torch.Tensor = ...

outputs = ...  # preallocate a length-`iters` tensor, to be filled at each iteration
for i in range(iters):
    interim_tensor = myblock(input)
    outputs[i] = interim_tensor

if is_training:
    return interim_tensor

if is_inferring:
    return outputs # we want a log of intermediate results too, just in case

Now, because iters is stochastic, this graph is generated dynamically: its length changes from call to call.

However, on TPUs this might trigger a recompilation every time iters changes, slowing down training.

So how do I convert this to a static graph? I do know the maximum number of iterations I can use - defined through randrange.
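Concretely, the fixed-length unroll I have in mind looks something like this (a rough sketch, not a working implementation: MAX_ITERS is the upper bound implied by randrange, myblock is a stand-in for my real block, and I’ve assumed each pass consumes the previous pass’s output - swap in the constant input if that’s wrong):

```python
import torch

MAX_ITERS = 9  # randrange(1, 10) draws at most 9 iterations

def unrolled_forward(x, myblock):
    # The loop length is a compile-time constant, so the traced graph
    # has the same shape no matter what `iters` was actually drawn.
    outputs = []
    for _ in range(MAX_ITERS):
        x = myblock(x)       # assumption: each pass feeds on the previous output
        outputs.append(x)
    return torch.stack(outputs)  # (MAX_ITERS, *x.shape); pick the i-th slice later
```
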

So perhaps I can forward through every iteration anyway and extract the ith tensor I want? I could use torch.where, but I’m worried I’ll be propagating the wrong gradients elsewhere, since those selection ops would be recorded on the graph too. I do want them on the graph (so it stays static) - they just shouldn’t have any effect on the backprop, because the loss should only see the intermediate tensor I actually picked.
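For the selection step, one variant I’ve sketched multiplies the stacked outputs by a one-hot mask instead of indexing, so the chosen index lives in a tensor and the graph stays fixed. As far as I understand, the zeroed-out slices then contribute exactly zero gradient, which is the behaviour I’m after - but I’d appreciate confirmation. (forward_and_select and its arguments are my own hypothetical names; MAX_ITERS is the randrange bound again.)

```python
import torch
import torch.nn.functional as F

MAX_ITERS = 9  # randrange(1, 10) draws at most 9 iterations

def forward_and_select(x, myblock, iters):
    # `iters` is a 0-d LongTensor, not a Python int, so changing its
    # value does not change the traced graph.
    outs = []
    h = x
    for _ in range(MAX_ITERS):
        h = myblock(h)       # assumption: each pass feeds on the previous output
        outs.append(h)
    stacked = torch.stack(outs)  # (MAX_ITERS, *x.shape)
    # One-hot mask over the iteration axis: only the (iters-1)-th slice
    # survives; the rest are multiplied by zero, so they stay on the
    # graph but contribute nothing to the loss or its gradient.
    mask = F.one_hot(iters - 1, MAX_ITERS).to(stacked.dtype)
    mask = mask.view(MAX_ITERS, *([1] * (stacked.dim() - 1)))
    selected = (stacked * mask).sum(dim=0)
    return selected, stacked
```

With a toy block you can check that only the selected slice carries gradient back to the input, even though all MAX_ITERS passes are on the graph.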

Anyone have any ideas?