The answer is: kind of. I can at least give you an update.
Someone opened an issue on GitHub after running into what I think is a superset of this problem. He saw that with varying input sizes the first 20 iterations were slow, whereas with a fixed input size only the first 2 iterations were slow (the second one unacceptably so, much like what prompted my question here).
link: The first 20 loops of inference are Extremely Slow on C++ libtorch v1.8.1 · Issue #56245 · pytorch/pytorch · GitHub
He was referred to a workaround for the 20 slow iterations, which was to decrease the optimization (bailout) depth, as described in:
link: JIT makes model run x14 times slower in pytorch nightly · Issue #52286 · pytorch/pytorch · GitHub
The problem for me was that this workaround did not fix the slow 2nd iteration. As it turns out, however, setting the bailout_depth to 0 fixes exactly that.
I'm not sure whether this is an acceptable or recommended solution, or whether a better fix is coming in a future version of PyTorch (issue 52286 is still open), but you may want to try it.
So in conclusion, try:
or if you are writing C++:
torch::jit::getBailoutDepth() = 0;
and see whether it works for you without too much of a performance penalty.
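To judge the "penalty" part, I'd time every iteration individually rather than looking at an average, since the whole problem is a handful of slow warm-up iterations. Below is a small, library-agnostic sketch; the helper names (`timeIterations`, `medianMs`, `slowIterations`) and the "slower than N times the median" threshold are my own choices for illustration, not anything from libtorch.

```cpp
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <functional>
#include <vector>

// Runs `step` n times and returns the wall-clock duration of each call in milliseconds.
std::vector<double> timeIterations(const std::function<void()>& step, std::size_t n) {
    std::vector<double> ms;
    ms.reserve(n);
    for (std::size_t i = 0; i < n; ++i) {
        auto start = std::chrono::steady_clock::now();
        step();
        auto end = std::chrono::steady_clock::now();
        ms.push_back(std::chrono::duration<double, std::milli>(end - start).count());
    }
    return ms;
}

// Median of the per-iteration times, used as the "steady state" reference.
// Takes a copy because nth_element reorders the data.
double medianMs(std::vector<double> ms) {
    assert(!ms.empty());
    std::size_t mid = ms.size() / 2;
    std::nth_element(ms.begin(), ms.begin() + mid, ms.end());
    double upper = ms[mid];
    if (ms.size() % 2 == 1) return upper;
    // Even count: average the two middle values (the lower one is the max of the left part).
    double lower = *std::max_element(ms.begin(), ms.begin() + mid);
    return (lower + upper) / 2.0;
}

// Returns the 0-based indices of iterations slower than `factor` times the median.
std::vector<std::size_t> slowIterations(const std::vector<double>& ms, double factor) {
    assert(factor > 0.0);
    std::vector<std::size_t> slow;
    if (ms.empty()) return slow;
    double threshold = factor * medianMs(ms);
    for (std::size_t i = 0; i < ms.size(); ++i) {
        if (ms[i] > threshold) slow.push_back(i);
    }
    return slow;
}
```

In a libtorch program you would set `torch::jit::getBailoutDepth() = 0;` once before the first forward call, pass a lambda that runs `module.forward(inputs)` to `timeIterations`, and then compare both the list of slow iterations and the median against a run with the default setting. The median tells you whether the steady-state speed got worse.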