How to prevent memory oscillation when using PyTorch?

If I use too much memory in PyTorch, I notice thrashing. E.g., the GPU memory usage fluctuates when I check it with nvidia-smi or gpustat. Ideally, I'd force PyTorch to never deallocate memory, since the repeated deallocation kills performance. Is there a way I can do this?

You could try enabling the expandable_segments option of the CUDA caching allocator, as described here.
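
As a minimal sketch of one way to turn this on: the option is read from the PYTORCH_CUDA_ALLOC_CONF environment variable, which has to be set before the CUDA allocator is initialized. The tensor sizes below are arbitrary, and whether this actually smooths out the usage you see in nvidia-smi depends on your workload and PyTorch version, so treat it as an illustration rather than a guaranteed fix.

```python
import os

# Must be set before the CUDA caching allocator initializes,
# i.e. before the first tensor is placed on the GPU.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

# With expandable segments, the allocator grows existing segments instead of
# repeatedly returning memory to the driver and re-requesting it, which is
# what tends to show up as fluctuating usage in nvidia-smi / gpustat.
x = torch.randn(4096, 4096, device="cuda")
y = x @ x

# Optional: compare what the allocator has handed out to tensors vs. what it
# keeps reserved from the driver.
print(torch.cuda.memory_allocated() / 1e6, "MB allocated")
print(torch.cuda.memory_reserved() / 1e6, "MB reserved")
```

The same setting can be applied without touching the script by exporting it in the shell, e.g. `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python train.py`.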