Caching allocator reserving too much memory

Hi there,

I’m seeing a large jump (~10x) in the memory PyTorch tries to allocate when the input feature map for my Conv3d layer in the script below grows beyond ~4.3 GB. I grow the feature map by incrementing its size along dim=2 (x_dim=361 or 362).

With input size 4.29GB (x_dim=361), I get:

input GB:  4.29133696
max alloc. (GB):  12.874233344
max res. (GB):  12.880707584

And with input size 4.30GB (x_dim=362), I get:

input GB:  4.30322432
...
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 108.21 GiB (GPU 0; 79.21 GiB total capacity; 6.01 GiB already allocated; 72.66 GiB free; 6.01 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Is something going wrong with the allocator?

Script

import torch
from torch import nn

# x_dim = 361   # works - max memory ~ 12.9 GB.
x_dim = 362     # breaks - tries to allocate ~108 GiB.
input_shape = (1, 64, x_dim, 370, 251)
device = torch.device('cuda:0')
input = torch.rand(input_shape)

input = input.half().to(device)
input_GB = input.numel() * input.element_size() / 1e9
print('input GB: ', input_GB)
input.requires_grad = False
layer = nn.Conv3d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1).half().to(device)
for i, param in enumerate(layer.parameters()):
    param_GB = param.numel() * param.element_size() / 1e9
    # print(f'param_{i} GB: {param_GB}')
    param.requires_grad = False
output = layer(input)

print('max alloc. (GB): ', torch.cuda.max_memory_allocated() / 1e9)
print('max res. (GB): ', torch.cuda.max_memory_reserved() / 1e9)

Versions

PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Red Hat Enterprise Linux Server release 7.9 (Maipo) (x86_64)
GCC version: (GCC) 10.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.17

Python version: 3.8.6 (default, Mar 29 2021, 14:28:48) [GCC 10.2.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.66.1.el7.x86_64-x86_64-with-glibc2.2.5
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 32
On-line CPU(s) list: 0-31
Thread(s) per core: 1
Core(s) per socket: 16
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 106
Model name: Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz
Stepping: 6
CPU MHz: 2900.000
BogoMIPS: 5800.00
Virtualization: VT-x
L1d cache: 48K
L1i cache: 32K
L2 cache: 1280K
L3 cache: 24576K
NUMA node0 CPU(s): 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30
NUMA node1 CPU(s): 1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc aperfmperf eagerfpu pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch epb cat_l3 invpcid_single intel_pt ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq md_clear pconfig spec_ctrl intel_stibp flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] numpy==1.24.0
[pip3] pytorch-lightning==1.8.6
[pip3] torch==1.13.1
[pip3] torchaudio==0.13.1
[pip3] torchio==0.18.86
[pip3] torchmetrics==0.11.0
[pip3] torchvision==0.14.1
[conda] Could not collect

I’ve followed up with the corresponding issue here. This is a current limitation: batch-size-1 convolutions that require 64-bit indexing are not currently supported by cuDNN, so PyTorch relies on a native “im2col”-style kernel for this case, and that kernel is what requests the very large allocation.
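The numbers in the original post are consistent with this explanation. Below is a minimal sketch of the arithmetic in plain Python (no GPU needed). It rests on two assumptions that are not taken from the PyTorch source: that 32-bit indexing stops being sufficient once a tensor holds more than 2**31 - 1 elements, and that the im2col-style fallback builds a column buffer with in_channels * kernel_size**3 values per output position. The helper name conv3d_numbers is hypothetical.

```python
# Back-of-the-envelope check of the sizes reported in the post.
# Assumption: the im2col-style fallback materializes a column buffer holding
# in_channels * kernel_size**3 values for every output position.

INT32_MAX = 2**31 - 1   # assumed limit for 32-bit indexing
BYTES_PER_HALF = 2      # float16
GIB = 2**30


def conv3d_numbers(x_dim, in_channels=64, kernel_size=3, spatial=(370, 251)):
    """Return (input_elements, needs_64bit_indexing, im2col_bytes)."""
    # stride=1 and padding=1 with kernel_size=3 keep the output grid equal to the input grid
    positions = x_dim * spatial[0] * spatial[1]
    input_elements = in_channels * positions  # batch size is 1
    im2col_elements = in_channels * kernel_size**3 * positions
    needs_64bit = input_elements > INT32_MAX
    return input_elements, needs_64bit, im2col_elements * BYTES_PER_HALF


for x_dim in (361, 362):
    n, needs_64bit, col_bytes = conv3d_numbers(x_dim)
    print(f"x_dim={x_dim}: input elements={n}, "
          f"64-bit indexing={needs_64bit}, "
          f"im2col buffer={col_bytes / GIB:.2f} GiB")
```

With x_dim=361 the input has 2,145,668,480 elements, just under 2**31 = 2,147,483,648. With x_dim=362 it has 2,151,612,160, just over, and the assumed column buffer comes to 108.21 GiB, which matches the size in the error message.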
