Training slows down after an outlier variable-sized minibatch and does not recover

I am training a neural network where the input tensors can have different sizes between minibatches. It seems like the moment a single minibatch larger than the others is processed, that minibatch and all the minibatches after it run massively slower (about 1.6x), even though the size goes back to being constant for every minibatch after the one outlier.

It feels almost like PyTorch enters a “variable sequence length” mode and stays prepared to handle all future minibatches as if they were variable length, even if that means a 60% slowdown. Any tips on how to eliminate this performance drop, at least for the common case of the smaller batch size?

If I pad everything to the size of the first outlier I see, performance doesn’t drop until the second, even larger outlier. But the maximum possible outlier size is quite big, and I really don’t want every single minibatch to use that worst-case size, since it would waste a lot of compute on padded entries.
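
Concretely, the padding is essentially this (the shapes and the pad_to helper are simplified stand-ins for what I actually do):

```python
import torch
import torch.nn.functional as F

def pad_to(x: torch.Tensor, target_len: int) -> torch.Tensor:
    """Zero-pad the last (variable-sized) dimension of x up to target_len."""
    missing = target_len - x.size(-1)
    return F.pad(x, (0, missing)) if missing > 0 else x

batch = torch.randn(4, 8, 296)   # a typical minibatch
padded = pad_to(batch, 320)      # pad up to the first outlier size I saw
print(padded.shape)              # torch.Size([4, 8, 320])
```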

What I tried that didn’t work:
I padded only the first minibatch to the largest possible size, following the pre-allocation advice in the performance tuning guide. I still see the problem at the first outlier, even though the first minibatch is larger than it. Even if I pad everything up to the first outlier to the largest size, performance still tanks from the second outlier onward.
I also set torch.backends.cudnn.benchmark = False; this made no difference either. (A stripped-down sketch of both attempts is below.)
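
Roughly, the two attempts look like this (the Conv1d and the hard-coded sizes are stand-ins for my real model and data):

```python
import torch
import torch.nn as nn

MAX_LEN = 512                                                # assumed worst-case variable size
model = nn.Conv1d(8, 8, kernel_size=3, padding=1).cuda()     # stand-in for my network

# Attempt 1: one warm-up forward/backward pass with a worst-case-sized batch
# before real training starts (the pre-allocation idea from the tuning guide).
warmup = torch.randn(4, 8, MAX_LEN, device="cuda")
model(warmup).sum().backward()
model.zero_grad(set_to_none=True)

# Attempt 2: disable cuDNN autotuning so kernel selection shouldn't depend on input size.
torch.backends.cudnn.benchmark = False
```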

Example of this problem:
I print the variable-sized dimension of the tensor every iteration, and the elapsed time gets logged about every 5 iterations. Notice that the moment a size-320 batch shows up, all the iterations afterward become much slower, even though they are all back to size 296 as before.
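
For context, the loop below is a self-contained approximation of the loop structure that produces this output (the Conv1d model and the hard-coded sizes are stand-ins; the real loss and logging are more involved):

```python
import time
import torch
import torch.nn as nn

model = nn.Conv1d(8, 8, kernel_size=3, padding=1).cuda()   # stand-in for my network
optimizer = torch.optim.Adam(model.parameters())
LOG_EVERY = 5

start = time.time()
for step in range(1, 51):
    # Variable last dimension: mostly 296, with a single 320 outlier thrown in.
    length = 320 if step == 24 else 296
    x = torch.randn(4, 8, length, device="cuda")
    print(length)

    optimizer.zero_grad(set_to_none=True)
    model(x).sum().backward()
    optimizer.step()

    if step % LOG_EVERY == 0:
        torch.cuda.synchronize()               # make the timing meaningful on GPU
        print(f"time {(time.time() - start) / LOG_EVERY:.2f}s")
        start = time.time()
```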

296
296
296
296
296
296
296
296
07-11 21:55:30 INFO: | Epoch 1 | Step 8 [5/5000] | lr 0.000500 | loss 8.6681 | rc loss 8.6681 | kl loss 0.0000 | beta  0.30 | time 1.81s
296
296
296
296
296
07-11 21:55:34 INFO: | Epoch 1 | Step 13 [10/5000] | lr 0.001000 | loss 4.5439 | rc loss 4.5439 | kl loss 0.0000 | beta  0.30 | time 0.83s
296
296
296
296
296
07-11 21:55:38 INFO: | Epoch 1 | Step 18 [15/5000] | lr 0.001500 | loss 3.1491 | rc loss 3.1491 | kl loss 0.0000 | beta  0.30 | time 0.83s
296
296
296
296
296
07-11 21:55:42 INFO: | Epoch 1 | Step 23 [20/5000] | lr 0.002000 | loss 2.4775 | rc loss 2.4775 | kl loss 0.0000 | beta  0.30 | time 0.83s
296
296
296
320
296
07-11 21:55:47 INFO: | Epoch 1 | Step 28 [25/5000] | lr 0.002500 | loss 2.2208 | rc loss 2.2208 | kl loss 0.0000 | beta  0.30 | time 0.92s
296
296
296
296
296
07-11 21:55:54 INFO: | Epoch 1 | Step 33 [30/5000] | lr 0.003000 | loss 2.6465 | rc loss 2.6465 | kl loss 0.0000 | beta  0.30 | time 1.46s
296
296
296
296
296
07-11 21:56:01 INFO: | Epoch 1 | Step 38 [35/5000] | lr 0.003000 | loss 3.0696 | rc loss 3.0696 | kl loss 0.0000 | beta  0.30 | time 1.31s
296
296
296
296
296
07-11 21:56:08 INFO: | Epoch 1 | Step 43 [40/5000] | lr 0.003000 | loss 3.4430 | rc loss 3.4430 | kl loss 0.0000 | beta  0.30 | time 1.40s
296
296
296
296
296
07-11 21:56:14 INFO: | Epoch 1 | Step 48 [45/5000] | lr 0.003000 | loss 2.8866 | rc loss 2.8866 | kl loss 0.0000 | beta  0.30 | time 1.27s
296
296
296
296