CPU version of `torch.cummax` is slow

For context, I was implementing a modified version of CenterNet, which uses a module called Center/CornerPooling.
This module is essentially a cumulative max (cummax) along each spatial direction.
In my model, there are only 4 cummax calls.
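
To make that concrete, here is a tiny standalone example (not part of my model) of the equivalence: a corner pool in one direction is a running max along that axis, and the opposite direction is the same thing with a flip before and after.

import torch

x = torch.tensor([[1., 3., 2., 5., 4.]])

# one direction: running max along the last dim
print(torch.cummax(x, dim=-1)[0])                     # tensor([[1., 3., 3., 5., 5.]])

# opposite direction: flip, take the running max, flip back
print(torch.cummax(x.flip(-1), dim=-1)[0].flip(-1))   # tensor([[5., 5., 5., 5., 4.]])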

I noticed that inference is very slow on the CPU, about 500-1000ms per image.
After profiling the model, I found that roughly 1/4-1/3 of the inference time is spent in torch.cummax.
I replaced it with the cpp CornerPooling implementation from GitHub - bishwarup307/cornerpool: Cornerpool in Pytorch and profiled again.
The inference time dropped a lot, but the CornerPooling from that repo triggers some warnings in the backward pass (which I don’t really understand), so I reimplemented it once more with torch.unbind and torch.stack.

After that, I benchmarked the three implementations.
The benchmark runs each of them on a torch.rand(1, 128, 256, 256) tensor and times it with timeit.
On CPU, the cpp version and the unbind-stack version perform virtually the same, but the cummax version is about 10x slower.
(The cpp version is the reference, because I use it to verify correctness.)

On CPU:
- Reference impl:		32.58 ± 0.94ms
- Cummax impl:		328.65 ± 7.42ms
- Unbind-stack impl:	35.17 ± 1.05ms
On CUDA:
- Reference impl:		1.88 ± 0.05ms
- Cummax impl:		0.03 ± 0.02ms
- Unbind-stack impl:	2.95 ± 1.79ms

What do you think, did I miss something?

Here’s the code to replicate this experiment.

import timeit

import numpy as np
import torch
from cornerpool import BottomPool, LeftPool, RightPool, TopPool
from torch import Tensor
from torch.jit import script


# Function for implementation verification
def correction_assert(f1, f2, x):
    assert (f1(x) == f2(x)).all()


def take_grad(f):
    def wrapped(x):
        x = x.clone()
        x.requires_grad = True
        loss = torch.square(f(x)).mean()
        loss.backward()
        return x.grad

    return wrapped


# Corner pooling: unbind-stack version
@script
def corner_pool(x: Tensor, dim: int, flip: bool):
    sz = x.size(dim)
    outputs = list(x.unbind(dim))

    for i in range(1, sz):
        if flip:
            i_in = sz - i
            i_out = sz - i - 1
        else:
            i_in = i - 1
            i_out = i
        outputs[i_out] = torch.maximum(outputs[i_out], outputs[i_in])

    return torch.stack(outputs, dim=dim)


@script
def top_pool(x: Tensor) -> Tensor:
    return corner_pool(x, -2, True)


@script
def bottom_pool(x: Tensor) -> Tensor:
    return corner_pool(x, -2, False)


@script
def left_pool(x: Tensor) -> Tensor:
    return corner_pool(x, -1, True)


@script
def right_pool(x: Tensor) -> Tensor:
    return corner_pool(x, -1, False)


# Corner pooling: cummax version
@script
def corner_pool_cm(img, dim: int, flip: bool):
    if flip:
        img = torch.flip(img, dims=(dim,))
    pooled, _ = torch.cummax(img, dim=dim)
    if flip:
        pooled = torch.flip(pooled, dims=(dim,))
    return pooled


@script
def top_pool_cm(x: Tensor) -> Tensor:
    return corner_pool_cm(x, -2, True)


@script
def bottom_pool_cm(x: Tensor) -> Tensor:
    return corner_pool_cm(x, -2, False)


@script
def left_pool_cm(x: Tensor) -> Tensor:
    return corner_pool_cm(x, -1, True)


@script
def right_pool_cm(x: Tensor) -> Tensor:
    return corner_pool_cm(x, -1, False)


# Check for correctness
img = torch.rand(1, 32, 48, 48)

# Result check
correction_assert(top_pool, TopPool(), img)
correction_assert(bottom_pool, BottomPool(), img)
correction_assert(left_pool, LeftPool(), img)
correction_assert(right_pool, RightPool(), img)
correction_assert(top_pool_cm, TopPool(), img)
correction_assert(bottom_pool_cm, BottomPool(), img)
correction_assert(left_pool_cm, LeftPool(), img)
correction_assert(right_pool_cm, RightPool(), img)

# Gradient check
correction_assert(take_grad(top_pool), take_grad(top_pool_cm), img)
correction_assert(take_grad(bottom_pool), take_grad(bottom_pool_cm), img)
correction_assert(take_grad(left_pool), take_grad(left_pool_cm), img)
correction_assert(take_grad(right_pool), take_grad(right_pool_cm), img)

# Benchmark (CPU)
# Reference impl:		32.58 ± 0.94ms
# Cummax impl:		328.65 ± 7.42ms
# Unbind-stack impl:	35.17 ± 1.05ms
img = torch.rand(1, 128, 256, 256)
bottom_pool_1 = BottomPool()
bottom_pool_2 = bottom_pool_cm
bottom_pool_3 = bottom_pool
repeat = 3
number = 10
scale = 1000 / number
opt = {"globals": globals(), "repeat": repeat, "number": number}
with torch.no_grad():
    r = timeit.repeat("bottom_pool_1(img)", **opt)
    m = (np.mean(r) * scale).round(decimals=2)
    s = (np.std(r) * scale).round(decimals=2)
    print(f"Reference impl:\t\t{m} ± {s}ms")

    r = timeit.repeat("bottom_pool_2(img)", **opt)
    m = (np.mean(r) * scale).round(decimals=2)
    s = (np.std(r) * scale).round(decimals=2)
    print(f"Cummax impl:\t\t{m} ± {s}ms")

    r = timeit.repeat("bottom_pool_3(img)", **opt)
    m = (np.mean(r) * scale).round(decimals=2)
    s = (np.std(r) * scale).round(decimals=2)
    print(f"Unbind-stack impl:\t{m} ± {s}ms")

# Benchmark (CUDA)
# cummax wins this time, by about 100x.
# Reference impl:		1.88 ± 0.05ms
# Cummax impl:		0.03 ± 0.02ms
# Unbind-stack impl:	2.95 ± 1.79ms
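# (Caveat: these CUDA timings never call torch.cuda.synchronize(), so they may
# mostly reflect kernel-launch overhead rather than actual kernel execution time.)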
img = torch.rand(1, 128, 256, 256).cuda()
repeat = 3
number = 10
scale = 1000 / number
opt = {"globals": globals(), "repeat": repeat, "number": number}
with torch.no_grad():
    r = timeit.repeat("bottom_pool_1(img)", **opt)
    m = (np.mean(r) * scale).round(decimals=2)
    s = (np.std(r) * scale).round(decimals=2)
    print(f"Reference impl:\t\t{m} ± {s}ms")

    r = timeit.repeat("bottom_pool_2(img)", **opt)
    m = (np.mean(r) * scale).round(decimals=2)
    s = (np.std(r) * scale).round(decimals=2)
    print(f"Cummax impl:\t\t{m} ± {s}ms")

    r = timeit.repeat("bottom_pool_3(img)", **opt)
    m = (np.mean(r) * scale).round(decimals=2)
    s = (np.std(r) * scale).round(decimals=2)
    print(f"Unbind-stack impl:\t{m} ± {s}ms")

Hi ndgnuh!

I can reproduce the slowness you see. @ptrblck: Is this a performance bug?

Here are some timings from a simplified version of your script:

2.0.1
11.8
GeForce GTX 1050 Ti
cpu equal: True
cpu cummax time:       2.887537717819214
cpu corner_pool time:  0.34276819229125977
gpu equal: True
gpu cummax time:       0.015532255172729492
gpu corner_pool time:  0.04727458953857422

And here is the script itself:

import torch
print (torch.__version__)
print (torch.version.cuda)
print (torch.cuda.get_device_name())

from time import time

_ = torch.manual_seed (2023)

def corner_pool(x: torch.Tensor, dim: int, flip: bool):
    sz = x.size(dim)
    outputs = list(x.unbind(dim))
    
    for i in range(1, sz):
        if flip:
            i_in = sz - i
            i_out = sz - i - 1
        else:
            i_in = i - 1
            i_out = i
        outputs[i_out] = torch.maximum(outputs[i_out], outputs[i_in])
    
    return torch.stack(outputs, dim=dim)

img = torch.rand (1, 128, 256, 256)

cmA = torch.cummax (img, dim = -2)[0]
cmB = corner_pool (img, -2, False)
print ('cpu equal:', torch.equal (cmA, cmB))

t0 = time()
for  i in range (10):
    cmA = torch.cummax (img, dim = -2)[0]

print ('cpu cummax time:      ', time() - t0)

t0 = time()
for  i in range (10):
    cmB = corner_pool (img, -2, False)

print ('cpu corner_pool time: ', time() - t0)

img = img.cuda()

cmA = torch.cummax (img, dim = -2)[0]
cmB = corner_pool (img, -2, False)
print ('gpu equal:', torch.equal (cmA, cmB))

torch.cuda.synchronize()
t0 = time()
for  i in range (10):
    cmA = torch.cummax (img, dim = -2)[0]

torch.cuda.synchronize()
print ('gpu cummax time:      ', time() - t0)

torch.cuda.synchronize()
t0 = time()
for  i in range (10):
    cmB = corner_pool (img, -2, False)

torch.cuda.synchronize()
print ('gpu corner_pool time: ', time() - t0)

Best.

K. Frank

Thanks for pinging!
The implementation should already be parallelized on the CPU through its use of TensorIterator, as seen here, but I’m not familiar enough with the CPU backend.
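
A rough way to check how much the CPU path actually parallelizes on a given machine (just a sketch, I haven’t profiled this myself) would be to time cummax at different intra-op thread counts:

import time
import torch

img = torch.rand(1, 128, 256, 256)

for n in (1, 2, 4, torch.get_num_threads()):
    torch.set_num_threads(n)
    t0 = time.time()
    for _ in range(10):
        torch.cummax(img, dim=-2)
    print(f"{n} thread(s): {time.time() - t0:.3f}s")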

@ndgnuh would you mind creating an issue on GitHub pointing to this thread so that we could track and fix it? Also, thanks for reporting it!

Sorry for the late reply.

As suggested, I have opened an issue here.