I’m using PyTorch on TPUs, and I want to implement an early exit that stops execution of my layers.
Say, for simplicity’s sake, I have a single block, myblock(), that I want to forward through multiple times:
    iters = randrange(1, 10)
    input: torch.Tensor = ...
    outputs = ...  # preallocate an `iters`-length tensor, to be filled at each iteration
    for i in range(iters):
        interim_tensor = myblock(input)
        outputs[i] = interim_tensor
    if is_training:
        return interim_tensor
    if is_inferring:
        return outputs  # we want a log of intermediate results too, just in case
Since iters is stochastic, this graph is generated dynamically. However, on TPUs this might trigger a recompilation every time iters changes, slowing down training.
So how do I convert this to a static graph? I do know the maximum number of iterations I can use (it’s defined ahead of time). So perhaps I can forward through every iteration up to that maximum and extract the ith tensor I want? I could use torch.where, but I’m worried I’d be propagating the wrong gradients elsewhere, since those extra ops would be recorded on the graph too. I do want them on the graph (so it stays static); they just shouldn’t have any effect on the backprop, because the result I actually use comes from an intermediate tensor.
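Concretely, here’s roughly what I have in mind (an untested sketch; MAX_ITERS and forward_static are just placeholder names, and myblock is the block from above):

    import torch
    from random import randrange

    MAX_ITERS = 10  # assumed known upper bound on the number of iterations

    def forward_static(myblock, x, is_training=True):
        # Always unroll to MAX_ITERS, so the op count and shapes never change.
        iters = randrange(1, MAX_ITERS)

        outputs = []
        for _ in range(MAX_ITERS):
            outputs.append(myblock(x))       # same call pattern as the snippet above
        stacked = torch.stack(outputs)       # shape: (MAX_ITERS, *x.shape)

        # Build a one-hot selection mask on the host and move it to the device as
        # ordinary input data, so the chosen index is data rather than graph structure.
        mask = torch.zeros(MAX_ITERS)
        mask[iters - 1] = 1.0
        mask = mask.to(device=x.device, dtype=stacked.dtype)
        mask = mask.view(-1, *([1] * x.dim()))

        # Only the selected iteration contributes to the result, so the masked-out
        # iterations should receive zero gradient in backprop.
        selected = (stacked * mask).sum(dim=0)

        if is_training:
            return selected
        return stacked                        # full log of intermediate results

My understanding is that the masked-out iterations never reach the loss, so their gradients are exactly zero (at the cost of still paying the forward compute for them), and selecting with torch.where should behave the same way, but I’d like confirmation.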
Anyone have any ideas?