GPU MaxPool2d too slow with stride=1

I have a large, sparse binary tensor of shape [1, 1600, 512, 512] (values are either 0 or 1), and I want to count, along dim=1, the number of elements equal to 1 within a spatial window, say [40, 40].

My solution was to pass the tensor through nn.MaxPool2d((40, 40), stride=1) and then sum along dim=1. This turned out to be very slow and used too much GPU memory (out-of-memory errors). So I split the image into chunks along dim=1 with torch.chunk, which fixed the memory problem but was still just as slow.
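For concreteness, here is a minimal sketch of the pipeline I mean, on toy shapes (the real shapes are in the benchmark below); F.max_pool2d is just the functional form of the same pooling:

import torch
import torch.nn.functional as F

# toy stand-in for the real [1, 1600, 512, 512] binary tensor
x = (torch.rand(1, 8, 64, 64) > 0.5).float()
# per channel: each output element is 1 if the 40x40 window contains at least one 1
pooled = F.max_pool2d(x, kernel_size=(40, 40), stride=1)  # [1, 8, 25, 25]
# per window position: how many channels contain a 1
count = pooled.sum(dim=1)                                  # [1, 25, 25]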

I put together a minimal runnable benchmark of both approaches, inspired by code from @ptrblck (One channel conv2d on GPU takes too much time - #4 by ptrblck):

import torch
import time

torch.backends.cudnn.benchmark = True

nb_iter = 10
# dummy binary image: [1, 1600, 512, 512] float32 is ~1.6 GiB on its own
img = torch.randint(low=0, high=2, size=[1, 1600, 512, 512]).cuda().float()

kernel_shape = [40, 40]
max_pool2d = torch.nn.MaxPool2d((kernel_shape[0], kernel_shape[1]), stride=1)
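# with stride=1 and no padding, each spatial dim shrinks to 512 - 40 + 1 = 473 -> output [1, 1600, 473, 473]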

## Max Pooling
# warmup
for _ in range(10):
    max_pool_op = max_pool2d(img)

# test
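# synchronize so pending GPU work from the warmup doesn't leak into the timing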
torch.cuda.synchronize()
tic = time.time()
for _ in range(nb_iter):
    max_pool_op = max_pool2d(img)
torch.cuda.synchronize()
toc = time.time() - tic
print("GPU MaxPool Time:", toc / nb_iter)

## Max Pooling Chunking
torch.cuda.empty_cache()
# warmup
for _ in range(10):
    max_pool_chunk_op = []
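    # split the 1600 channels into 16 chunks of 100 channels each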
    x = torch.chunk(img, 16, dim=1)
    for chunk in x:
        op = max_pool2d(chunk)
        max_pool_chunk_op.append(op)
    max_pool_chunk_op = torch.cat(max_pool_chunk_op, dim=1)

# test
torch.cuda.synchronize()
tic = time.time()
for _ in range(nb_iter):
    max_pool_chunk_op = []
    x = torch.chunk(img, 16, dim=1)
    for chunk in x:
        op = max_pool2d(chunk)
        max_pool_chunk_op.append(op)
    max_pool_chunk_op = torch.cat(max_pool_chunk_op, dim=1)
torch.cuda.synchronize()
toc = time.time() - tic
print("GPU MaxPool Chunk Time:", toc / nb_iter)

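# sanity check: fraction of elements where the chunked and full outputs agree (should be 1.0)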
print(
    torch.sum(torch.eq(max_pool_chunk_op.cpu(), max_pool_op.cpu()))
    / max_pool_op.cpu().nelement()
)

And it gave me these results:

GPU MaxPool Time: 5.212354826927185
GPU MaxPool Chunk Time: 5.250830364227295
tensor(1.)

I'm using a GeForce GTX TITAN X and PyTorch 1.8.1 with CUDA 10.2. GPU utilization stayed at 100% while the code was running.

Is there an obvious reason why max pooling is this slow (~5 s per iteration)? In my application the number of channels can be even larger, so this is currently a major bottleneck. Is there a way to make it faster (under a second)?