torch Conv1d GPU memory spikes during backward pass on smaller inputs

(Short reproducible code below.)

A very bizarre behavior while using torch.nn.Conv1d:
when feeding a smaller input (below some threshold), GPU memory usage spikes sharply during the backward pass, by an order of magnitude or more.

We hypothesize that this has to do with torch/CUDA choosing different convolution algorithms depending on the input dimensions and the available memory; the problem is that this results in very undesirable OOM errors at runtime.
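One way to test this hypothesis is to profile the backward pass and look at which CUDA kernels run and how much memory they allocate. A minimal sketch using torch.profiler (the profiled_conv helper is only for illustration; the kernel names hint at the algorithm cuDNN picked, and profile_memory=True adds per-operator allocation figures):

import torch
from torch import nn
from torch.profiler import profile, ProfilerActivity


def profiled_conv(n_samples):
    conv = nn.Conv1d(512 + 64, 1024, kernel_size=1, padding='valid').cuda()
    v = torch.rand(n_samples, 512 + 64, 11).cuda()
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 profile_memory=True) as prof:
        conv(v).mean().backward()
    # Kernel names indicate the chosen convolution algorithm (implicit GEMM, FFT, ...);
    # sorting by memory makes the backward-pass spike visible.
    print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=10))


profiled_conv(5000)  # the "cheap" case below
profiled_conv(4000)  # the case whose backward explodes below

The original repro and the pytorch_memlab measurements follow.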

import torch
from torch import nn
%load_ext pytorch_memlab

base_ch = 512
d_pos = 64


def print_gpu_mem_usage(prefix=""):
    print(f"{prefix}Peak memory: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB"
          f" | {torch.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB"
          f" (Current: {torch.cuda.memory_allocated() / 1024 ** 3:.2f} GB"
          f" | {torch.cuda.memory_reserved() / 1024 ** 3:.2f} GB)")


def isolated_conv(v):
    samp_conv = nn.Conv1d(base_ch + d_pos, 2 * base_ch, kernel_size=1, padding='valid').cuda()
    mn = samp_conv(v).mean()
    mn.backward()

%mlrun -f isolated_conv isolated_conv(torch.rand(5000, base_ch+d_pos, 11).cuda())

active_bytes  reserved_bytes  line  code
         all             all
        peak            peak
     108.00M         108.00M     6  def isolated_conv(v):
     328.00M         346.00M     7      mn = nn.Conv1d(…)
     542.00M         562.00M     8      mn.backward()

However, switch the number of samples from 5000 to 4000 and it explodes:

%mlrun -f isolated_conv isolated_conv(torch.rand(4000, base_ch+d_pos, 11).cuda())

active_bytes  reserved_bytes  line  code
         all             all
        peak            peak
      86.00M          86.00M     6  def isolated_conv(v):
     260.00M         280.00M     7      mn = nn.Conv1d(…)
       8.07G           8.25G     8      mn.backward()

The same happens if I run the two in the opposite order.
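For anyone without pytorch_memlab, roughly the same comparison can be made with plain PyTorch counters and the print_gpu_mem_usage helper above (a sketch):

# Reset the peak counters, run one forward+backward, then read the peak.
torch.cuda.reset_peak_memory_stats()
isolated_conv(torch.rand(5000, base_ch + d_pos, 11).cuda())
print_gpu_mem_usage("n=5000 | ")

torch.cuda.reset_peak_memory_stats()
isolated_conv(torch.rand(4000, base_ch + d_pos, 11).cuda())
print_gpu_mem_usage("n=4000 | ")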

Edit:
I tried torch.cuda.empty_cache(); it doesn't help. The same goes for cudnn.benchmark=True/False.

I just read about torch.backends.cudnn.deterministic=True, which changes the behavior slightly, but the problem still occurs when I go down to 3000 samples. I'm going to try training with it and see whether it at least avoids crashing on OOM.
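For reference, the workarounds mentioned above amount to roughly these settings (a sketch; so far none of them removes the spike):

import torch

torch.backends.cudnn.benchmark = False      # tried both True and False; no difference
torch.backends.cudnn.deterministic = True   # restricts cuDNN to deterministic algorithms; spike persists below ~3000 samples
torch.cuda.empty_cache()                    # releases unoccupied cached memory back to the driver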

This runs in a Docker container, so if you can't reproduce it with the following versions I can share a Dockerfile.

torch == 2.0.1
pytorch-memlab == 0.3.0
Nvidia 2080Ti
Driver Version: 525.105.17
CUDA Version: 12.0

Could you post the code snippet resulting in the OOM? This should never happen, since cuDNN won't use an algorithm which requires a workspace larger than the available memory.

I can share the section in which the OOM happens. I tracked it down with the TensorBoard profiler and can tell that it happens in the backward of this Conv1d:
nn.Conv1d(base_ch + D_POS, 2 * base_ch, kernel_size=1, padding='valid')
which led me to investigate why it happens on random batches during the run.

This function is part of a very large, complex graph implemented with PL/Torch/PG (which, unfortunately, I can't share).

The input to the first, raw_conv, is a long concatenated vector of shape 1 x in_channels x <order of 10^3-10^4>.
The input to the second, samp_conv, when it crashes is an array of shape <order of 10^3> x in_channels x 11 (see the shape-check sketch after the module definitions below).

base_ch = 512
d_pos = 64
in_channels = 106
d_segment_feature = 512

self.raw_conv = nn.Sequential(
    nn.Conv1d(in_channels=in_channels, out_channels=base_ch, kernel_size=1, padding='same'),
    nn.ReLU(inplace=True),
    nn.Conv1d(base_ch, base_ch, kernel_size=3, groups=base_ch // 32, padding='same'),
    nn.ReLU(inplace=True),
    nn.Conv1d(base_ch, base_ch, kernel_size=3, groups=base_ch // 32, padding='same')
)

self.init_segment_position = nn.Sequential(
    nn.Conv1d(5, d_pos // 2, kernel_size=3, padding='same'),
    nn.ReLU(inplace=True),
    nn.Conv1d(d_pos // 2, d_pos, kernel_size=3, padding='same')
)

self.samp_conv = nn.Sequential(
    nn.Conv1d(base_ch + d_pos, 2 * base_ch, kernel_size=1, padding='valid'),
    nn.Conv1d(2 * base_ch, 2 * base_ch, kernel_size=3, padding='valid', groups=2 * base_ch // 32, stride=2),
    nn.ReLU(inplace=True),
    nn.Conv1d(2 * base_ch, 2 * base_ch, kernel_size=3, padding='valid', groups=2 * base_ch // 32, stride=2),
    nn.ReLU(inplace=True),
    nn.Conv1d(2 * base_ch, 2 * base_ch, kernel_size=2, padding='valid', groups=2 * base_ch // 64, stride=2),
    nn.ReLU(inplace=True),
    nn.Conv1d(2 * base_ch, d_segment_feature,
              kernel_size=1, padding='valid', groups=d_segment_feature // 64),
)
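
For context, a hypothetical shape check matching the inputs described above (assuming model is an instance of this module and everything is on the GPU; the lengths are made up):

raw_in = torch.rand(1, in_channels, 4096).cuda()         # 1 x in_channels x <order of 10^3-10^4>
samp_in = torch.rand(4000, base_ch + d_pos, 11).cuda()   # <order of 10^3> x (base_ch + d_pos) x 11

print(model.raw_conv(raw_in).shape)    # torch.Size([1, 512, 4096])  -- 'same' padding keeps the length
print(model.samp_conv(samp_in).shape)  # torch.Size([4000, 512, 1])  -- strided 'valid' convs: 11 -> 5 -> 2 -> 1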

Short update: adding
torch.backends.cudnn.deterministic=True
doesn't solve the problem, unfortunately.

Without a code snippet to reproduce the issue I won't be able to help, unfortunately, since an OOM should never be raised by cuDNN. We only use conv algorithms whose workspace would fit into the GPU memory without causing an OOM; if an algorithm would require more memory, we ignore it, never call into its kernel, and never try to allocate the workspace.

@ptrblck voilà, a minimal example; technical details follow.

This creates an OOM for me within a few iterations of the loop.

import torch
from torch import nn
from numpy.random import randint


def isolated_conv(v):
    print("FWD")
    mn = nn.Conv1d(512, 1024, kernel_size=1, padding='valid').cuda()(v).mean()
    print("BWD")
    mn.backward()
    print("Done with conv")


rand_int = lambda: randint(3000, 5000)

for ii in range(100):
    n_samples = rand_int()
    n_array_sq = rand_int()
    print(f"iteration {ii} | n_samples: {n_samples} | n_array_sq: {n_array_sq}")
    print("#1, running 1st conv+bwd")
    isolated_conv(torch.rand(n_samples, 512, 11).cuda())
    print("#1, create temp tensor")
    temp = torch.randn(n_array_sq, 1024).cuda()
    print("#1, reshape temp")
    temp = temp.reshape(n_array_sq, 512, 2)
    print("#1, delete temp")
    del temp
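
To see how close the device is to its limit when the failing backward is issued, the free memory can be logged each iteration (a hedged sketch; torch.cuda.mem_get_info returns free and total device bytes):

# Hypothetical addition at the top of the loop body, before isolated_conv(...):
free_b, total_b = torch.cuda.mem_get_info()
print(f"free: {free_b / 1024 ** 3:.2f} GB of {total_b / 1024 ** 3:.2f} GB")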

Specs:
This runs on a 2080 Ti, in a Docker image derived from an official PyTorch image:
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
which has:
Python 3.10.11
torch 2.0.1

@ptrblck, per your request I posted a minimal reproducible example above; any thoughts on a resolution?