Performance issue for conv2d with 1D filter along a dim

I have a Gaussian filter that I'm applying independently along the x and y dimensions of my data. The filter is therefore a 1D filter that is applied first along, e.g., the width and then along the height.

I noticed that there are significant performance differences depending on the dimension, and which dimension is slower differs between CUDA and CPU. On the CPU the difference is 3 s vs. 300 ms! On CUDA the slow dimension is flipped, with a difference of 24 ms vs. 4 ms!

This is on Windows with PyTorch 2.2.2+cu121. I tested it with the code below and got the results below. I can obviously flip dimensions to get a speedup, but I worry that I may be missing something fundamental about CPU and CUDA. I think I understand the CUDA result, because the last dimension has a stride of 1, but I don't understand the CPU result.
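For reference, here is a quick way to see the stride layouts I mean (a standalone sketch using the same shape as below; moveaxis returns a view with permuted strides, no copy):

import torch

data = torch.rand((1, 1, 3848, 3222))
print(data.stride())                 # (12398256, 12398256, 3222, 1): last dim has stride 1
print(data.moveaxis(2, 3).stride())  # (12398256, 12398256, 1, 3222): same memory, strides swapped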

Test code:

import torch
import torch.nn.functional as F
from torch.utils.benchmark import Timer, Compare
from itertools import product


device = "cpu"
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

gauss_kernel = torch.tensor(
    [[2.2306e-05, 4.2847e-05, 8.0049e-05, 1.4545e-04, 2.5706e-04, 4.4185e-04,
         7.3867e-04, 1.2011e-03, 1.8994e-03, 2.9215e-03, 4.3705e-03, 6.3590e-03,
         8.9989e-03, 1.2386e-02, 1.6580e-02, 2.1587e-02, 2.7336e-02, 3.3668e-02,
         4.0330e-02, 4.6987e-02, 5.3244e-02, 5.8680e-02, 6.2900e-02, 6.5576e-02,
         6.6493e-02, 6.5576e-02, 6.2900e-02, 5.8680e-02, 5.3244e-02, 4.6987e-02,
         4.0330e-02, 3.3668e-02, 2.7336e-02, 2.1587e-02, 1.6580e-02, 1.2386e-02,
         8.9989e-03, 6.3590e-03, 4.3705e-03, 2.9215e-03, 1.8994e-03, 1.2011e-03,
         7.3867e-04, 4.4185e-04, 2.5706e-04, 1.4545e-04, 8.0049e-05, 4.2847e-05,
         2.2306e-05]],
    dtype=torch.float32,
    device=device
)

data = torch.rand((1, 1, 3848, 3222), dtype=torch.float32, device=device)

x_kernel = gauss_kernel.view(1, 1, -1, 1)  # ZCXY=11K1
y_kernel = gauss_kernel.view(1, 1, 1, -1)  # ZCXY=111K


def run(do_x, do_y, flip_axis):
    if do_x:
        if flip_axis:
            F.conv2d(data.moveaxis(2, 3), x_kernel.moveaxis(2, 3), padding="same")
        else:
            F.conv2d(data, x_kernel, padding="same")
    if do_y:
        if flip_axis:
            F.conv2d(data.moveaxis(2, 3), y_kernel.moveaxis(2, 3), padding="same")
        else:
            F.conv2d(data, y_kernel, padding="same")


for x, y, flip in product(*([(True, False)] * 3)):
    if not x and not y:
        continue
    result = Timer(
        setup="from __main__ import run",
        stmt=f"run({x}, {y}, {flip})",
        label="2D conv",
        sub_label=f"x={int(x)}, y={int(y)}, flip={int(flip)}"
    ).timeit(5)
    print(result)

CPU results sorted by time:

2D conv: x=1, y=1, flip=1
  3.21 s

2D conv: x=1, y=1, flip=0
  3.06 s

2D conv: x=1, y=0, flip=1
  2.83 s

2D conv: x=0, y=1, flip=0
  2.76 s

2D conv: x=0, y=1, flip=1
  363.43 ms

2D conv: x=1, y=0, flip=0
  295.50 ms

CUDA results sorted by time:

2D conv: x=1, y=1, flip=1
  37.08 ms

2D conv: x=1, y=1, flip=0
  26.74 ms

2D conv: x=0, y=1, flip=1
  24.43 ms

2D conv: x=1, y=0, flip=0
  23.28 ms

2D conv: x=1, y=0, flip=1
  5.37 ms

2D conv: x=0, y=1, flip=0
  4.07 ms

Hi Matham!

I don’t have an explanation for the timing results you are seeing. However,
I have a suggestion for yet another timing experiment you could try.

Although packaged as a conv2d(), you have, in essence, a one-dimensional,
length-49 convolution kernel. Build from it the [49, 49] two-dimensional
“product” kernel, and apply it as a single conv2d() to data.
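Something along these lines (an untested sketch, reusing the gauss_kernel and data names from your snippet):

import torch
import torch.nn.functional as F

g = gauss_kernel.view(-1)                                 # the length-49 1D kernel
product_kernel = torch.outer(g, g).view(1, 1, 49, 49)     # [49, 49] "product" kernel
blurred = F.conv2d(data, product_kernel, padding="same")  # one 2D conv instead of two 1D convs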

If your one-dimensional kernel were much smaller (say of length 3), this would
likely be faster than your current scheme. In your case, the [49, 49] product
kernel is significantly larger than the [49] one-dimensional kernel, so my
proposed scheme would use significantly more floating-point operations. But,
especially with gpus, the raw number of floating-point operations can be less
important than how “smoothly” those operations can be streamed through the
gpu’s floating-point pipelines.
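To put rough numbers on that: the separable scheme does about 49 + 49 = 98 multiply-accumulates per output pixel, while the [49, 49] product kernel does 49 × 49 = 2401, roughly 24 times as many.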

It’s possible (but I don’t know how likely) that (depending on your gpu, etc.)
the single conv2d() call with the larger kernel could run faster than the two
conv2d() calls with the smaller kernels.

Just an idea …

Best.

K. Frank

I appreciate the suggestion. For the GPU, I think you're right that the trade-off between two smaller convs and one large conv could in principle favor the single large conv. I tested it to be sure (using the flipping technique to time the fastest version of each):

On my laptop GPU, the two smaller convs took 14 ms total vs. 850 ms for the single large conv. On the CPU, it was 663 ms total for the two smaller convs vs. at least 5 min for the large one, after which I gave up waiting.

Some people will run this code on the CPU, so having it fast there is also a priority for me!
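For reference, here is roughly what the flip trick looks like wrapped into a helper. It's only a sketch, and the choice of kernel orientation is an assumption based on the timings above ((K, 1) kernels were fastest on my CPU, (1, K) kernels on my GPU):

import torch
import torch.nn.functional as F

def separable_gauss_blur(img, kernel_1d, prefer_row_kernel):
    # img: [N, C, H, W]; kernel_1d: 1D tensor of length K.
    # prefer_row_kernel=True uses a (1, K) kernel (the fast orientation on CUDA above),
    # False uses a (K, 1) kernel (the fast orientation on CPU above).
    if prefer_row_kernel:
        k = kernel_1d.view(1, 1, 1, -1)
    else:
        k = kernel_1d.view(1, 1, -1, 1)
    out = F.conv2d(img, k, padding="same")                 # blur along one spatial axis
    out = F.conv2d(out.moveaxis(2, 3), k, padding="same")  # swap H/W, blur along the other axis
    return out.moveaxis(2, 3)                              # restore [N, C, H, W] order

# e.g.: separable_gauss_blur(data, gauss_kernel.view(-1), prefer_row_kernel=(device == "cuda"))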