F.conv2d with large tensor: Tried to allocate 81.00 GiB

Environment

Python 3.10.9
torch 2.0.1
GPU: Tesla P40

import torch
import torch.nn.functional as F
a = torch.randn(1, 256, 3072, 3072).cuda()  # needs 9 GiB on its own
b = torch.randn(256, 256, 3, 3).cuda()
c = torch.randn(256).cuda()
y = F.conv2d(a, b, c, (1, 1), (1, 1), (1, 1), 1)  # tries to allocate 81 GiB

Questions:
Why does F.conv2d need 81 GiB of VRAM?
Is there any way to compute conv2d more slowly, without running out of memory and without changing the original result?
Thanks

Hey!

The default algorithm needs a workspace proportional to the kernel size, which is massive here due to the number of channels.
One thing you should try is installing cuDNN on your machine and making sure it is available via torch.backends.cudnn.version(). IIRC it has at least one convolution algorithm that does not require any extra workspace.
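To see where the 81 GiB comes from, here is a back-of-the-envelope sketch. It assumes the non-cuDNN path lowers the convolution with im2col, materializing a float32 "columns" buffer of shape (C_in·kH·kW, H_out·W_out) per batch element, which is how PyTorch's native slow_conv2d path works. The helper names are hypothetical.

```python
GIB = 1024 ** 3
BYTES_PER_FLOAT32 = 4


def conv_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Standard output-size formula for one spatial dimension of a convolution."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def im2col_bytes(c_in, h, w, kh, kw, stride=1, padding=0, dilation=1):
    """Bytes of a float32 im2col buffer of shape (c_in*kh*kw, h_out*w_out)."""
    h_out = conv_out_size(h, kh, stride, padding, dilation)
    w_out = conv_out_size(w, kw, stride, padding, dilation)
    return c_in * kh * kw * h_out * w_out * BYTES_PER_FLOAT32


# Shapes from the question: input (1, 256, 3072, 3072), weight (256, 256, 3, 3), padding 1
input_bytes = 1 * 256 * 3072 * 3072 * BYTES_PER_FLOAT32
cols_bytes = im2col_bytes(256, 3072, 3072, 3, 3, padding=1)

print(f"input tensor : {input_bytes / GIB:.2f} GiB")  # input tensor : 9.00 GiB
print(f"im2col buffer: {cols_bytes / GIB:.2f} GiB")   # im2col buffer: 81.00 GiB
```

Under that assumption the buffer is kH·kW = 9 times the size of the input, which lines up with the 81.00 GiB allocation in the error message.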

cuDNN is available.
How do I set the conv2d algorithm in PyTorch? Thanks.

@Zhang_Jiguo, while you cannot set the cuDNN algorithm, you can limit the maximum workspace size that cuDNN is allowed to use to compute the convolution. That should be sufficient for your purpose.

You can set this limit with the environment variable CUDNN_CONV_WSCAP_DBG; for example, CUDNN_CONV_WSCAP_DBG=4096 caps the workspace at 4096 megabytes.
In your code, set this environment variable before import torch, for example:

import os
os.environ["CUDNN_CONV_WSCAP_DBG"] = "4096"  # environment values must be strings
import torch

Alternatively, you can specify it on the command line:

CUDNN_CONV_WSCAP_DBG=4096 python your_script.py
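If you set the variable from Python, the value has to be a string (os.environ rejects an int with a TypeError), and per the advice above it should be set before torch is imported. A small guard along these lines can help; set_cudnn_workspace_cap is a hypothetical helper, and the warning only reflects the "set it before import torch" advice, not a documented PyTorch behavior.

```python
import os
import sys


def set_cudnn_workspace_cap(megabytes: int) -> None:
    """Set CUDNN_CONV_WSCAP_DBG (hypothetical helper; call it before `import torch`)."""
    if "torch" in sys.modules:
        # torch was already imported in this process, so the setting may come too late
        print("warning: torch already imported; CUDNN_CONV_WSCAP_DBG may have no effect",
              file=sys.stderr)
    # os.environ only accepts str values; assigning an int raises TypeError
    os.environ["CUDNN_CONV_WSCAP_DBG"] = str(megabytes)


set_cudnn_workspace_cap(4096)
# import torch  # import torch only after the variable is set
```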


cuDNN won't be used here because of the large input size: needs_64bit_indexing_no_split will return true, so use_cudnn will return false, forcing the fallback to slow_conv2d_forward_cuda.
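The input really is past the 32-bit range. A quick check, assuming the relevant threshold is the largest signed 32-bit integer; the exact condition inside needs_64bit_indexing_no_split may differ (for example, it may also look at other tensors involved in the call).

```python
import math

INT32_MAX = 2**31 - 1  # 2_147_483_647


def numel(shape):
    """Number of elements in a tensor with the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n


input_shape = (1, 256, 3072, 3072)
n = numel(input_shape)
print(f"input: {n:,} elements, exceeds int32: {n > INT32_MAX}")
# input: 2,415,919,104 elements, exceeds int32: True

# Largest square spatial size that stays under the limit with 256 channels (batch 1)
max_side = math.isqrt(INT32_MAX // 256)
print(max_side)  # 2896
```

Under the same assumption, 256-channel tiles of a few hundred pixels per side are far below the threshold, which is what makes tiling attractive.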


env CUDNN_CONV_WSCAP_DBG=4096 python test_conv.py  # still needs 81 GiB of VRAM

Can I configure slow_conv2d_forward_cuda to run without running out of memory?

Yes. As @ptrblck replied, the convolution is so large that cuDNN doesn't support it, hence it still runs via the slow_conv2d path, which takes a lot of working memory.

One way to solve this is to tile the convolution into patches.
Each individual patch of the convolution actually ends up using cuDNN (and will be quite fast), and the overall computation is mathematically the same.
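The key detail of tiling is the halo: for a k×k kernel with stride 1 and "same" padding p = (k - 1) // 2, output rows [i, i + T) depend on input rows [i - p, i + T + p). The sketch below plans tiles along one dimension (rows and columns are planned independently); plan_tiles is a hypothetical helper, not a PyTorch API.

```python
def plan_tiles(size, tile, kernel):
    """Plan 1-D tiles for a stride-1 convolution with 'same' padding.

    Returns (out_start, out_end, in_start, in_end, pad_before, pad_after) per tile:
    output positions [out_start, out_end) are computed from input positions
    [in_start, in_end), zero-padded by pad_before/pad_after where the halo
    runs past the image border.
    """
    p = (kernel - 1) // 2
    plan = []
    for out_start in range(0, size, tile):
        out_end = min(out_start + tile, size)
        in_start = max(out_start - p, 0)
        in_end = min(out_end + p, size)
        pad_before = p - (out_start - in_start)  # nonzero only near the start border
        pad_after = p - (in_end - out_end)       # nonzero only near the end border
        plan.append((out_start, out_end, in_start, in_end, pad_before, pad_after))
    return plan


for entry in plan_tiles(size=10, tile=4, kernel=3):
    print(entry)
# (0, 4, 0, 5, 1, 0)
# (4, 8, 3, 9, 0, 0)
# (8, 10, 7, 10, 0, 1)
```

Each padded input window is exactly kernel - 1 elements longer than its output span, so an unpadded ("valid") convolution over it yields precisely that span.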

Here's some sample code that I verified works correctly (it gives the same output for F.conv2d and tiled_conv2d).

Here's the tiled conv function:

import torch
import torch.nn.functional as F

def tiled_conv2d(input, weight, bias, tile_size):
    """Compute the same function as F.conv2d(input, weight, bias, stride=1, padding=k//2),
    but tile by tile, accounting for border effects.

    Assumes an odd square kernel of size k, stride 1, dilation 1 and groups 1."""
    # Initialize the output tensor (output channels come from the weight, not the input)
    y_full = torch.zeros(input.size(0), weight.size(0), input.size(2), input.size(3),
                         device=input.device, dtype=input.dtype)

    overlap = weight.size(2) - 1  # Kernel size - 1
    pad = overlap // 2            # "same" padding, i.e. 1 for a 3x3 kernel

    for i in range(0, input.shape[2], tile_size):
        for j in range(0, input.shape[3], tile_size):
            # Calculate the region of interest with overlap (the halo)
            start_i = max(i - pad, 0)
            end_i = min(i + tile_size + pad, input.shape[2])
            start_j = max(j - pad, 0)
            end_j = min(j + tile_size + pad, input.shape[3])

            # Extract the tile
            tile = input[:, :, start_i:end_i, start_j:end_j]

            # Process the tile; at interior edges its outermost rows/columns see zero
            # padding instead of real neighbors, so they are cropped away below
            conv_tile = F.conv2d(tile, weight, bias, stride=1, padding=pad)

            # Determine the region in the output tensor to update
            output_start_i = i
            output_end_i = min(i + tile_size, input.shape[2])
            output_start_j = j
            output_end_j = min(j + tile_size, input.shape[3])

            # With "same" padding, conv_tile index k corresponds to input position start + k,
            # so crop the halo to match the output region
            tile_i_start = i - start_i
            tile_i_end = tile_i_start + (output_end_i - output_start_i)
            tile_j_start = j - start_j
            tile_j_end = tile_j_start + (output_end_j - output_start_j)

            # Place the convolved tile in the output tensor
            y_full[:, :, output_start_i:output_end_i, output_start_j:output_end_j] = conv_tile[:, :, tile_i_start:tile_i_end, tile_j_start:tile_j_end]

    # y_full now contains the full convolved image
    return y_full

And here's how to use it:

import torch
import torch.nn.functional as F

# Original image and kernel
input_size = 3072  # use a smaller size to check tiled_conv2d against F.conv2d
input = torch.randn(1, 256, input_size, input_size).cuda()
weight = torch.randn(256, 256, 3, 3).cuda()
bias = torch.randn(256).cuda()

tile_size = 256  # use 256 x 256 image patches
y_full = tiled_conv2d(input, weight, bias, tile_size)

# With a smaller input_size, uncomment the next two lines to verify tiled_conv2d
# (at 3072 x 3072 the plain F.conv2d call OOMs)
# y = F.conv2d(input, weight, bias, (1, 1), (1, 1), (1, 1), 1)
# print(torch.allclose(y, y_full))
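For comparison, here is a hedged alternative sketch that takes each tile's input window including its halo, zero-pads only where the window crosses the image border (F.pad), and then runs an unpadded convolution, so the tile output needs no cropping. It assumes stride 1, dilation 1, groups 1 and an odd square kernel with "same" padding; tiled_conv2d_halo is a hypothetical name. The check at the bottom runs on CPU at a small size so that the full F.conv2d fits in memory.

```python
import torch
import torch.nn.functional as F


def tiled_conv2d_halo(x, weight, bias=None, tile=256):
    """Tile-by-tile equivalent of F.conv2d(x, weight, bias, stride=1, padding=k // 2).

    Assumes an odd square kernel, stride 1, dilation 1 and groups 1.
    """
    n, _, h, w = x.shape
    p = (weight.shape[-1] - 1) // 2
    out = x.new_empty((n, weight.shape[0], h, w))  # output channels come from the weight
    for i in range(0, h, tile):
        i_end = min(i + tile, h)
        for j in range(0, w, tile):
            j_end = min(j + tile, w)
            # Input window with a halo of p pixels, clipped to the image
            r0, r1 = max(i - p, 0), min(i_end + p, h)
            c0, c1 = max(j - p, 0), min(j_end + p, w)
            window = x[:, :, r0:r1, c0:c1]
            # Zero-pad only the sides where the halo ran past the border;
            # F.pad takes (left, right, top, bottom) for the last two dims
            window = F.pad(window, (p - (j - c0), p - (c1 - j_end),
                                    p - (i - r0), p - (r1 - i_end)))
            # A "valid" convolution over the padded window yields exactly this tile
            out[:, :, i:i_end, j:j_end] = F.conv2d(window, weight, bias)
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(1, 8, 37, 50)  # small, odd sizes exercise the border tiles
    wt = torch.randn(16, 8, 3, 3)  # 8 -> 16 channels
    b = torch.randn(16)
    ref = F.conv2d(x, wt, b, padding=1)
    got = tiled_conv2d_halo(x, wt, b, tile=16)
    print(torch.allclose(ref, got, rtol=1e-4, atol=1e-4))
    print((ref - got).abs().max().item())
```

Each tile is a small, ordinary F.conv2d call, so on a GPU it becomes eligible for cuDNN again; the full-size reference call only works here because the test input is tiny.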

Thanks, it works on the P40, but only with a tolerance:

torch.allclose(y, y_full, rtol=1e-3, atol=1e-3) # True
torch.allclose(y, y_full) # False
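The default tolerances of torch.allclose (rtol=1e-05, atol=1e-08) are tight for this comparison. Per the replies above, the full F.conv2d call and the tiled version run different code paths (slow_conv2d vs. cuDNN per tile), which plausibly accumulate the 256·3·3 = 2304 products per output value in a different order. Floating-point addition is not associative, so reordering alone changes the last bits, and float32 carries only about 7 significant decimal digits. A minimal illustration with plain Python doubles:

```python
import math

# Same three numbers, different summation order
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)
print(a, b, a == b)  # 0.6000000000000001 0.6 False

# A relative tolerance, in the spirit of torch.allclose(..., rtol=...), accepts both
print(math.isclose(a, b, rel_tol=1e-9))  # True
```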