Out-of-memory error when moving from PyTorch 2.5.1+cu121 to PyTorch 2.9+cu128

The same 3D registration script for medical imaging (VoxelMorph-style UNet with Conv3d upsampling blocks) runs fine on PyTorch 2.5.1+cu121 but OOMs immediately on PyTorch 2.9+cu128 during a Conv3d in the decoder. The 2.9 build appears to select a cuDNN algorithm that requires a ~27 GiB workspace on a 24 GiB GPU.
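For context, the decoder block that triggers the failure follows the usual Upsample + Conv3d pattern, roughly like the sketch below; the channel counts, names, and activation here are placeholders, not my actual code:

import torch
import torch.nn as nn

class DecoderBlock(nn.Module):
    # Illustrative sketch only: the real block uses different sizes and names.
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=False)
        self.conv = nn.Conv3d(in_ch + skip_ch, out_ch, kernel_size=3, padding=1)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x, skip):
        x = self.up(x)                    # upsample the coarse feature map
        x = torch.cat([x, skip], dim=1)   # concatenate the encoder skip connection
        return self.act(self.conv(x))     # a Conv3d like this is where env2 OOMs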

The two environments differ only in their PyTorch and CUDA versions; everything else (script, data, machine) is the same.
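For reference, I compare the two stacks by dumping version info with a few standard PyTorch calls (nothing custom here):

import torch

print("torch        :", torch.__version__)
print("CUDA (build) :", torch.version.cuda)
print("cuDNN        :", torch.backends.cudnn.version())
print("GPU          :", torch.cuda.get_device_name(0))
print("capability   :", torch.cuda.get_device_capability(0))

The bundled cuDNN version differs between the cu121 and cu128 wheels, which is why I suspect cuDNN algorithm selection rather than my own code.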

Below I post the logs from two runs, with memory usage printed at each step. In env1 the script runs correctly: at the first training step only 2.2 GB is allocated and 17.6 GB is reserved.

In env2 the script crashes with an OOM error because it tries to allocate 27 GiB right at the start.
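The memory lines in the logs come from a small helper along these lines (a sketch of what I call after each phase; the name log_mem and the exact formatting are mine, the underlying calls are standard torch.cuda and psutil APIs):

import logging
import psutil
import torch

def log_mem(tag, device=0):
    # Host-side memory of the current process.
    mem = psutil.Process().memory_info()
    logging.info("[%s] CPU RSS: %.2f GB, VMS: %.2f GB",
                 tag, mem.rss / 1e9, mem.vms / 1e9)
    # Device-side memory as tracked by the PyTorch caching allocator.
    alloc = torch.cuda.memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    peak = torch.cuda.max_memory_allocated(device)
    total = torch.cuda.get_device_properties(device).total_memory
    logging.info("[%s] GPU%d alloc: %.2f GB, reserved: %.2f GB, peak alloc: %.2f GB, total: %.2f GB",
                 tag, device, alloc / 1e9, reserved / 1e9, peak / 1e9, total / 1e9)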

EXAMPLE:
env1 PyTorch 2.5.1+cu121

13:57:45 - Starting. Patient: 101, Fixed suffix: 0min Moving suffix: 2h
13:57:45 - [startup] CPU RSS: 592.20 MB, VMS: 5.79 GB
13:57:45 - [startup] GPU0 alloc: 0.00 B, reserved: 0.00 B, peak alloc: 0.00 B, total: 23.63 GB
13:57:45 - [after data load] CPU RSS: 965.62 MB, VMS: 16.28 GB
13:57:45 - [after data load] GPU0 alloc: 128.00 MB, reserved: 128.00 MB, peak alloc: 128.00 MB, total: 23.63 GB
13:57:45 - STEP 1: Starting Global Registration
13:57:47 - [step1 epoch 1] CPU RSS: 2.28 GB, VMS: 36.35 GB
13:57:47 - [step1 epoch 1] GPU0 alloc: 2.20 GB, reserved: 17.56 GB, peak alloc: 16.04 GB, total: 23.63 GB
13:57:54 - Reg Step - Epoch [10/300], Combined Loss: -0.254139, NCC Loss: -0.822846, Gradient Loss: 0.030861, SSIM Loss: 0.313276, , DICE Coefficient: 0.7621
13:57:54 - [step1 epoch 10] CPU RSS: 2.28 GB, VMS: 37.89 GB
13:57:54 - [step1 epoch 10] GPU0 alloc: 2.20 GB, reserved: 19.06 GB, peak alloc: 16.39 GB, total: 23.63 GB

env2 PyTorch 2.9+cu128
12:39:49 - Starting. Patient: 101, Fixed suffix: 0min Moving suffix: 2h
12:39:49 - [startup] CPU RSS: 721.36 MB, VMS: 7.21 GB
12:39:49 - [startup] GPU0 alloc: 0.00 B, reserved: 0.00 B, peak alloc: 0.00 B, total: 23.63 GB
12:39:49 - [after data load] CPU RSS: 1.07 GB, VMS: 17.69 GB
12:39:49 - [after data load] GPU0 alloc: 128.00 MB, reserved: 128.00 MB, peak alloc: 128.00 MB, total: 23.63 GB
12:39:50 - STEP 1: Starting Global Registration
Traceback (most recent call last):
[…]
File "xxx", line 641, in forward
y = self.decoder3 # Upsample
^^^^^^^^^^^^^^^^^^
[…]
File "xxx", line 712, in _conv_forward
return F.conv3d(
^^^^^^^^^
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 27.00 GiB. GPU 0 has a total capacity of 23.63 GiB of which 7.78 GiB is free. Including non-PyTorch memory, this process has 15.20 GiB memory in use. Of the allocated memory 7.85 GiB is allocated by PyTorch, and 6.88 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
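
Based on the error message and on the workspace hypothesis above, these are the standard switches I plan to try next; whether any of them actually avoids the 27 GiB workspace request in 2.9 is exactly what I'm unsure about:

import os

# From the error message: allow expandable segments in the caching allocator.
# Needs to be set before the first CUDA allocation (or exported in the shell).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch

# Keep cuDNN from autotuning towards the fastest (and possibly most
# workspace-hungry) convolution algorithm for each input shape.
torch.backends.cudnn.benchmark = False

# Last resort: bypass cuDNN for conv ops and use PyTorch's native kernels,
# which sidesteps the cuDNN workspace choice entirely (at a speed cost).
# torch.backends.cudnn.enabled = False

Is there a recommended way to cap the cuDNN convolution workspace in 2.9, or is one of the fallbacks above the expected workaround?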