I’m using PyTorch on TPUs, and I want to implement an early exit for my layers to stop execution.
Say, for simplicity’s sake, I have a single block `myblock()` that I want to forward through multiple times:

```python
iters = randrange(1, 10)
input: torch.Tensor = ...
outputs = ...  # preallocate a length-`iters` tensor, to be filled at each iteration

for i in range(iters):
    interim_tensor = myblock(input)
    outputs[i] = interim_tensor

if is_training:
    return interim_tensor
if is_inferring:
    return outputs  # we want a log of the intermediate results too, just in case
```
Now, because `iters` is stochastic, this graph is generated dynamically. On TPUs, that can trigger a recompilation every time `iters` changes, slowing down training.
So how do I convert this to a static graph? I do know the maximum number of iterations I can use, since it’s the upper bound of the `randrange` call.
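For what it’s worth, here’s a minimal sketch of the fixed-length unroll I’m imagining. `MAX_ITERS` and the doubling `myblock` are placeholders, not my real model, and I’m chaining the block (feeding each output back in), unlike the simplified snippet above. The point is that the loop always runs `MAX_ITERS` times, so the traced graph has the same shape regardless of the sampled `iters`:

```python
import torch

MAX_ITERS = 9  # upper bound of randrange(1, 10) -- hypothetical fixed cap

def myblock(x):
    # stand-in for the real block
    return x * 2.0

def forward_static(x):
    # always run MAX_ITERS iterations so the compiled graph never changes shape
    results = []
    interim = x
    for _ in range(MAX_ITERS):
        interim = myblock(interim)
        results.append(interim)
    return torch.stack(results)  # shape: (MAX_ITERS, *x.shape)

outs = forward_static(torch.tensor([1.0]))
# outs[0] == 2.0, outs[-1] == 512.0, and the shape is fixed at (9, 1)
```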
So perhaps I can forward through every layer anyway and extract the `i`-th tensor I want? I could use `torch.where`, but I’m worried I’d be propagating the wrong gradients elsewhere, since those extra ops would be recorded on the graph too. I do want them on the graph (so it stays static); they just shouldn’t have any effect on backprop, because the tensor I actually used is an intermediate one.
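To make the gradient question concrete, here’s a rough sketch of the selection I have in mind, using a one-hot mask rather than `torch.where` (the single learnable weight `w` and the chained `myblock` are made up purely so the gradient is checkable). The unselected rows are multiplied by zero, so they receive zero gradient even though their ops stay on the graph:

```python
import torch

MAX_ITERS = 9  # upper bound of randrange(1, 10)

def myblock(x, w):
    # hypothetical block with one learnable weight so gradients are checkable
    return x * w

def forward_select(x, w, iters):
    # fixed-length loop: the graph shape is independent of `iters`
    results = []
    interim = x
    for _ in range(MAX_ITERS):
        interim = myblock(interim, w)
        results.append(interim)
    outputs = torch.stack(results)  # (MAX_ITERS, *x.shape)

    # one-hot mask selects the (iters - 1)-th result; zeroed rows get zero grad
    mask = (torch.arange(MAX_ITERS) == iters - 1).float()
    mask = mask.view(-1, *([1] * x.dim()))
    return (outputs * mask).sum(dim=0)

w = torch.tensor(2.0, requires_grad=True)
selected = forward_select(torch.tensor([1.0]), w, iters=3)
selected.sum().backward()
# selected == w**3 == 8.0 and w.grad == 3 * w**2 == 12.0: only the chosen
# iteration's path contributed to the gradient
```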
Anyone have any ideas?