Yes, as @ptrblck replied, the conv is so large that cuDNN doesn’t support it, so it still runs via the slow_conv2d path, which takes a lot of working memory.
One way you can solve this is by tiling the convolution into patches.
Each individual patch ends up going through cuDNN (and will be quite fast), and the overall result is the same as that of the full convolution.
Here’s some sample code that I verified works correctly (it gives the same output for F.conv2d and tiled_conv2d).
Here’s the tiled conv function:
def tiled_conv2d(input, weight, bias, tile_size):
    """Compute the same result as F.conv2d, but tile by tile, accounting for border effects.

    Assumes a 3x3 kernel with stride 1 and padding 1, as in the example in this post.
    """
    # Initialize the output tensor (its channel count comes from the weight, not the input)
    y_full = torch.zeros(input.size(0), weight.size(0), input.size(2), input.size(3),
                         dtype=input.dtype, device=input.device)
    overlap = weight.size(2) - 1  # Kernel size - 1
    for i in range(0, input.shape[2], tile_size):
        for j in range(0, input.shape[3], tile_size):
            # Calculate the region of interest with overlap
            start_i = max(i - overlap // 2, 0)
            end_i = min(i + tile_size + overlap // 2, input.shape[2])
            start_j = max(j - overlap // 2, 0)
            end_j = min(j + tile_size + overlap // 2, input.shape[3])
            # Extract the tile
            tile = input[:, :, start_i:end_i, start_j:end_j]
            # Process the tile: stride 1, padding 1, dilation 1, groups 1
            # (padding 1 equals overlap // 2 for a 3x3 kernel)
            conv_tile = F.conv2d(tile, weight, bias, (1, 1), (1, 1), (1, 1), 1)
            # Determine the region in the output tensor to update
            output_start_i = i
            output_end_i = min(i + tile_size, input.shape[2])
            output_start_j = j
            output_end_j = min(j + tile_size, input.shape[3])
            # Crop the overlap rows/columns from the convolved tile, except at the image borders
            tile_i_start = 0 if i == 0 else overlap // 2
            tile_i_end = conv_tile.shape[2] - (0 if i + tile_size >= input.shape[2] else overlap // 2)
            tile_j_start = 0 if j == 0 else overlap // 2
            tile_j_end = conv_tile.shape[3] - (0 if j + tile_size >= input.shape[3] else overlap // 2)
            # Place the convolved tile in the output tensor
            y_full[:, :, output_start_i:output_end_i, output_start_j:output_end_j] = conv_tile[:, :, tile_i_start:tile_i_end, tile_j_start:tile_j_end]
    # y_full now contains the full convolved image
    return y_full
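If the index bookkeeping in the function above is hard to follow, here is a minimal 1D sketch of the same idea in plain Python (no GPU or PyTorch needed). The names conv1d_same and tiled_conv1d_same are hypothetical helpers made up for this illustration, and it assumes an odd-length kernel with stride 1 and zero padding. Each tile is read with a halo of len(k) // 2 extra samples per side, convolved on its own, and then the halo outputs are cropped before writing into the result.

```python
def conv1d_same(x, k):
    """Full 1D cross-correlation with zero padding, so len(out) == len(x).

    Assumes an odd-length kernel and stride 1.
    """
    pad = len(k) // 2
    padded = [0] * pad + list(x) + [0] * pad
    return [sum(padded[n + t] * k[t] for t in range(len(k))) for n in range(len(x))]


def tiled_conv1d_same(x, k, tile_size):
    """Same result as conv1d_same, computed tile by tile with a halo per side."""
    pad = len(k) // 2
    out = [0] * len(x)
    for i in range(0, len(x), tile_size):
        start = max(i - pad, 0)                   # input slice start, including halo
        end = min(i + tile_size + pad, len(x))    # input slice end, including halo
        conv_tile = conv1d_same(x[start:end], k)  # zero-padded conv of this tile only
        out_end = min(i + tile_size, len(x))      # output region is [i, out_end)
        crop = i - start                          # halo outputs to drop on the left
        out[i:out_end] = conv_tile[crop:crop + (out_end - i)]
    return out


x = list(range(1, 12))  # 11 samples, deliberately not a multiple of the tile size
k = [1, -2, 3]
full = conv1d_same(x, k)
tiled = tiled_conv1d_same(x, k, tile_size=4)
print(full == tiled)  # → True
```

The outputs computed next to an interior tile edge are wrong (they see zeros where the neighbouring tile's data should be), which is exactly why they are cropped. At the true image border the zero padding is the same as in the full convolution, so nothing is cropped there.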
And here’s how to use it:
import torch
import torch.nn.functional as F
# Original image and kernel
input_size = 3072 # change this to a smaller size if you want to verify the correctness of F.conv2d with tiled_conv2d
input = torch.randn(1, 256, input_size, input_size).cuda()
weight = torch.randn(256, 256, 3, 3).cuda()
bias = torch.randn(256).cuda()
tile_size = 256 # use 256 x 256 image patches
y_full = tiled_conv2d(input, weight, bias, tile_size)
# To verify the correctness of tiled_conv2d, use a smaller input_size and uncomment the next two lines
# y = F.conv2d(input, weight, bias, (1,1), (1,1), (1,1), 1) # OOMs at input_size = 3072
# print(torch.allclose(y, y_full))
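For a sense of scale, here is some back-of-the-envelope arithmetic for the shapes used above, in float32. The input and output sizes follow directly from the shapes. The column-buffer figure is only an assumption on my part: it supposes the fallback path materializes an im2col-style buffer of shape (C_in * k * k) x (H * W), and it is not a measurement. gib is a hypothetical helper for this sketch.

```python
def gib(num_elements, bytes_per_element=4):
    """Size in GiB of a tensor with the given element count (float32 by default)."""
    return num_elements * bytes_per_element / 2**30


c_in, c_out, size, k, tile = 256, 256, 3072, 3, 256

input_gib = gib(1 * c_in * size * size)    # full input tensor
output_gib = gib(1 * c_out * size * size)  # full output tensor
# ASSUMPTION: an im2col-style fallback allocates a (c_in * k * k) x (H * W) buffer
columns_gib = gib(c_in * k * k * size * size)
# One interior input tile as sliced by tiled_conv2d: tile_size plus a 1-pixel halo per side
tile_gib = gib(1 * c_in * (tile + k - 1) ** 2)

print(f"input {input_gib:.2f} GiB, output {output_gib:.2f} GiB")
print(f"assumed im2col buffer {columns_gib:.2f} GiB, one input tile {tile_gib:.3f} GiB")
```

Note that the input and output alone already need about 18 GiB, so tiling reduces only the working memory of the convolution itself, not the size of the tensors you keep around.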