I have an algorithm that performs a lot of 1D convolutions over 2D images, with kernels that are all ones. I first apply them row-wise and then column-wise (here I'm only showing one of the two passes). Each of those 1D convolutions is equivalent to summing the elements over a sliding window.
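To be concrete, this is the equivalence I mean (a small self-contained sketch with arbitrary sizes, separate from the benchmark code further down):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 16)          # one image row, shape (batch, channels, width)
k = 5                              # window length
ones_kernel = torch.ones(1, 1, k)  # the all-ones kernel

# Convolving with an all-ones kernel...
conv_out = F.conv1d(x, ones_kernel, padding=k // 2)

# ...produces the same values as summing over a sliding window of length k
# (only the borders differ, because of the zero padding).
window_sum = x.unfold(2, k, 1).sum(dim=-1)  # unfold(dim, size, step)
print(torch.allclose(conv_out[..., k // 2 : -(k // 2)], window_sum))  # True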
My first implementation uses Conv2D. I figured that Conv2D was doing a lot of unnecessary work in the patch × kernel multiplications: the kernel is all ones, so those multiplications could be skipped. With this in mind I made a second implementation that uses unfold/fold, the idea being that it would completely bypass the multiplication by the all-ones kernel and leave only the additions over each patch (for the 33-tap kernel below, roughly 33 multiplications plus 32 additions per output pixel become just 32 additions, i.e. about half the work). However, this second approach is significantly slower, even when running the unfold/fold version through torch.compile first.
My only explanation for this behaviour is that the underlying C++/CUDA convolution implementation is so much more optimized than anything I can build out of unfold/fold + torch.compile that it wins even while doing the extra multiplications.
Here is the code for both methods:
import time

import torch
import torch.nn as nn
from torch.nn.functional import fold, unfold


class OneDConv(nn.Module):
    def __init__(self, weights):
        super().__init__()
        kernel1 = tuple(weights.shape)
        self.conv1 = nn.Conv2d(
            1,
            1,
            kernel_size=kernel1,
            stride=1,
            padding=(int((kernel1[0] - 1) / 2), int((kernel1[1] - 1) / 2)),
            bias=False,
        )
        self.conv1.weight = nn.Parameter(weights.unsqueeze(0).unsqueeze(0))

    def forward(self, img):
        with torch.no_grad():
            out = self.conv1(img)
        return out


@torch.compile
def my_1d_conv(t):
    img_size = (1024, 1024)
    rad = 32
    kernel_size = (1, rad + 1)
    return fold(
        unfold(t, kernel_size=kernel_size, padding=(1, rad // 2)).contiguous(),
        output_size=img_size,
        kernel_size=kernel_size,
        padding=(1, rad // 2),
    )


if __name__ == "__main__":
    img = (
        torch.ones([1024, 1024], device="cuda", dtype=torch.float32)
        .unsqueeze(0)
        .unsqueeze(0)
    )
    n_iters = 20

    # CONV2D METHOD
    rad = 32
    conv = OneDConv(torch.ones([1, rad + 1], dtype=torch.float32, device="cuda"))
    # this one is to just warm up things
    conv(img)
    s = time.time()
    for _ in range(n_iters):
        first_method = conv(img)
    e = time.time()
    time_us = (e - s) * 1e6
    print("time with conv2D:", time_us / n_iters)

    # UNFOLD/FOLD METHOD
    # this one is to just warm up things
    other_conv = my_1d_conv(img)
    s = time.time()
    for _ in range(n_iters):
        second_method = my_1d_conv(img)
    e = time.time()
    time_us = (e - s) * 1e6
    print("time per unfold/fold:", time_us / n_iters)

    print(torch.sum(torch.abs(first_method - second_method)))
Here is the output from my 2080 Ti. Timings are in microseconds, averaged over 20 runs:
time with conv2D: 31.518936157226562
time per unfold/fold: 87.82148361206055
tensor(0., device='cuda:0')
I would imagine that if I got my hands dirty with CUDA I could get the all-ones-kernel convolution to run much faster, but I'd rather avoid that.
I have also observed that as I make the image larger and larger, the unfold/fold method runs out of memory (OOM) much sooner. That is not completely unexpected, since I am materializing all of those patches explicitly. Still, even though both methods are conceptually the same, whatever Conv2D does under the hood must be smarter in every way.
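For reference, a quick shape check (on the CPU, with the same sizes as in the script above) shows how much the explicit patches blow up memory:

import torch
from torch.nn.functional import unfold

img = torch.ones(1, 1, 1024, 1024)
patches = unfold(img, kernel_size=(1, 33), padding=(1, 16))
# Every pixel gets duplicated once for every window it belongs to, so the
# unfolded tensor holds roughly 33x (the kernel width) as many elements
# as the original image.
print(img.numel(), patches.numel(), tuple(patches.shape))
# 1048576 34670592 (1, 33, 1050624)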
My questions are:
- What is the explanation for the observed behaviour?
- Is there a way to achieve what I want within PyTorch? I.e. either a much faster convolution for the all-ones kernel case, or a way to make the unfold/fold approach realize the gains I was expecting.
Thanks!