Bizarre PyTorch CUDA memory allocation failure on Linux

I am encountering a bizarre CUDA memory allocation error on Linux (but not on Windows). I have two different machines: a Windows machine with an NVIDIA GeForce GTX 1050 with 4 GB of GPU memory, and a Google Cloud Linux VM with an NVIDIA Tesla T4 with 16 GB. My model (linked below) is a transformer, and transformers have a reputation for chewing up a lot of memory (see the Reformer paper, for example).

On my Windows machine, running the latest PyTorch release, the model trains fine in GPU memory for most training examples with a batch size of 4 (the training examples are variable sized). When I take the exact same model and training data over to the Linux VM, I start getting CUDA out-of-memory errors after around 40 steps. Increasing the batch size decreases the number of steps it takes to hit an out-of-memory error. That looks like a memory leak to me, so I made sure I wasn't holding onto any references so the garbage collector can do its thing. I have also added calls to clear the GPU cache, run the Python garbage collector, etc., which doesn't seem to help.
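For reference, the cleanup I added looks roughly like this (a minimal sketch; the helper name and where it gets called in my training loop are just illustrative):

import gc
import torch

def release_memory():
    # Drop unreachable Python objects so any dangling tensor references are freed
    gc.collect()
    # Return cached, unused blocks from PyTorch's caching allocator to the driver
    torch.cuda.empty_cache()

# Called between training steps
release_memory()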

The super weird part is that if I insert a call to

print(torch.cuda.memory_summary(device=None, abbreviated=True))

inside my exception handler for the CUDA out of memory error, the issue goes away! I am even able to bump the batch size up to 12 on the Linux instance, which is what I would expect from a system with more GPU memory.

Model is here:

Any ideas?

Update:

If I stick

torch.cuda.memory_summary(device=None, abbreviated=True)

in a loop that iterates 25 times, I can push the batch size up to 16. Sleeping or waiting in a busy loop does not reproduce the same behaviour.

Without looping over this call multiple times, I see periods where I get many CUDA allocation failures in a row. With the loop, I only see occasional CUDA allocation failures, which is what I would expect given that some input sequences are very long.

Could you explain a bit what that means?
Does it mean that running torch.cuda.memory_summary allows you to increase the batch size, while a sleep call does not?

That is correct. I have no idea why calling torch.cuda.memory_summary stops the out of memory errors, but I am glad it does. Basically, when I see a CUDA out of memory error, I call memory_summary 25 times in a row and then try one more time to feed the example to the network. Most of the time this fixes the out of memory error.
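Roughly, the handler looks like this (a minimal sketch; forward_with_retry and the bare model(batch) call stand in for my actual training step):

import torch

def forward_with_retry(model, batch):
    # Hypothetical wrapper around one training step; backward() and optimizer.step() omitted
    try:
        return model(batch)
    except RuntimeError as e:
        # Only handle CUDA out-of-memory errors; re-raise everything else
        if "out of memory" not in str(e):
            raise
        # Workaround: query the allocator stats 25 times, then retry the same example once
        for _ in range(25):
            torch.cuda.memory_summary(device=None, abbreviated=True)
        return model(batch)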

That's indeed weird, as this method should only collect and format the allocator stats.
Does a torch.cuda.synchronize() instead of torch.cuda.memory_summary() yield the same behavior?
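I.e. something along these lines, simply swapping synchronize in for the summary calls in the retry handler (a sketch, mirroring the snippet above):

import torch

def forward_with_retry_sync(model, batch):
    # Same retry pattern as above, but synchronizing instead of querying allocator stats
    try:
        return model(batch)
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        # Block until all pending CUDA work on the current device has finished
        torch.cuda.synchronize()
        return model(batch)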