Context: I am looking at wavelet transforms on 3D volumes. They are implemented as a series of convolutions with a 2-element kernel along the different axes. The implementation I was given uses a kind of “diagonal” (banded) matrix for this, which seemed inefficient to me: almost all entries of the matrix are zero, so most of the multiplications are wasted. But when I tried implementing it with convolutions instead, I found that they were actually slower. I tried both a Conv3D with singleton dimensions in the kernel and a Conv1D; both were slower.
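To make the banded-matrix idea concrete, here is a small sketch, assuming PyTorch on the CPU and a hypothetical 1D signal of length `n = 8`. It checks that multiplying by the banded matrix gives the same result as a stride-2 `F.conv1d`, and counts the work: the dense matrix has `n/2 * n` entries but only `n` of them are nonzero, so the matmul performs `n/2` times more multiply-adds than the convolution strictly needs.

```python
import torch
from torch.nn import functional as F

n = 8  # hypothetical signal length, kept small for illustration
kernel = torch.tensor([0.7071, 0.7071])
signal = torch.randn(n)

# Banded matrix: row i holds the two filter taps at columns 2i and 2i+1
matrix_kernel = torch.zeros(n // 2, n)
for i in range(n // 2):
    matrix_kernel[i, 2 * i:2 * i + 2] = kernel

via_matmul = signal @ matrix_kernel.T
# conv1d expects (batch, channels, length) input and (out_ch, in_ch, k) weights
via_conv = F.conv1d(signal[None, None, :], kernel[None, None, :], stride=2)[0, 0]

print(torch.allclose(via_matmul, via_conv, atol=1e-6))  # True
# Multiply-adds per line: dense matrix vs. the two taps per output actually needed
print(matrix_kernel.numel(), 2 * (n // 2))  # 32 8
```

For `size = 1024` the same count gives 512 times the arithmetic of the convolution.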
I created the following code as a minimal example:
```python
import torch
from torch.nn import functional as F
from itertools import product
import torch.utils.benchmark as benchmark

# Two-tap low-pass filter (1/sqrt(2), 1/sqrt(2))
kernel = torch.tensor([0.7071, 0.7071], device="cuda")

def conv_3d(x):
    # Filter the depth axis (dim 2) with a (2, 1, 1) kernel, downsampling by 2
    return F.conv3d(x, kernel[None, None, :, None, None], padding=(0, 0, 0), stride=(2, 1, 1))

def conv_1d(x):
    # Move the depth axis last and treat every (h, w) line as a batch element
    permuted = x.permute(0, 3, 4, 1, 2)
    result = F.conv1d(permuted.reshape(-1, 1, size), kernel[None, None, :], padding=0, stride=2)
    # Restore the (N, C, D/2, H, W) layout
    return result.reshape(1, size, size, 1, size // 2).permute(0, 3, 4, 1, 2)

def conv_1d_groups(x):
    # Same as conv_1d, but every (h, w) line is its own channel in one grouped conv
    permuted = x.permute(0, 1, 3, 4, 2)
    result = F.conv1d(permuted.reshape(1, -1, size), kernel[None, None, :].expand(size**2, 1, 2),
                      padding=0, stride=2, groups=size**2)
    return result.reshape(1, 1, size, size, size // 2).permute(0, 1, 4, 2, 3)

def matmul(x):
    # Multiply the last axis by the banded (size//2, size) filter matrix
    return torch.matmul(x, matrix_kernel.T)

results = []
for f, size in product([matmul, conv_3d, conv_1d, conv_1d_groups], [256, 512, 1024]):
    label = f.__name__
    sub_label = f'{size}'
    x = torch.randn(1, 1, size, size, size, device="cuda")
    # Banded matrix: row i holds the two filter taps at columns 2i and 2i+1
    matrix_kernel = torch.zeros((size // 2, size), dtype=torch.float32).cuda()
    for i in range(size // 2):
        matrix_kernel[i, i*2:i*2+2] = kernel
    results.append(benchmark.Timer(
        stmt='f(x)',
        globals={'x': x, 'f': f, 'matrix_kernel': matrix_kernel, 'kernel': kernel, 'size': size},
        # num_threads=num_threads,
        label="Test of separable convolution",
        sub_label=sub_label,
        description=label,
    ).blocked_autorange(min_run_time=1))

compare = benchmark.Compare(results)
compare.print()
```
This produces the following output (slightly reformatted):
```
[---------------- Test of separable convolution ----------------]
      |  matmul  |  conv_3d  |  conv_1d  |  conv_1d_groups
256   |   255.3  |    724.3  |    811.4  |    828.7
512   |  3613.9  |   5686.9  |  10500.9  |  10617.8
1024  | 53339.8  |  45561.8  |  87878.2  |  88775.5

Times are in microseconds (us).
```
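One thing worth checking in this benchmark is which axis each function filters. The sketch below (CPU, a small hypothetical `size = 8`) shows that `matmul` as written reduces the last axis, whose elements are contiguous in memory, while `conv_3d` (and, via the permutes, `conv_1d` and `conv_1d_groups`) reduce the depth axis, whose neighbouring samples are `size * size` elements apart. Applying the same banded matrix along depth with `torch.einsum` reproduces the `conv_3d` output, so that variant would arguably be the like-for-like matmul baseline.

```python
import torch
from torch.nn import functional as F

size = 8  # small CPU volume, just for checking shapes and values
kernel = torch.tensor([0.7071, 0.7071])
x = torch.randn(1, 1, size, size, size)

# Banded matrix: row i holds the two filter taps at columns 2i and 2i+1
matrix_kernel = torch.zeros(size // 2, size)
for i in range(size // 2):
    matrix_kernel[i, 2 * i:2 * i + 2] = kernel

# As benchmarked: matmul filters the last (W) axis, conv_3d the depth (D) axis
mm_last = torch.matmul(x, matrix_kernel.T)
print(mm_last.shape)  # torch.Size([1, 1, 8, 8, 4])
conv = F.conv3d(x, kernel[None, None, :, None, None], stride=(2, 1, 1))
print(conv.shape)  # torch.Size([1, 1, 4, 8, 8])

# Contracting the banded matrix with the depth axis instead reproduces conv_3d
matmul_depth = torch.einsum('od,ncdhw->ncohw', matrix_kernel, x)
print(torch.allclose(matmul_depth, conv, atol=1e-6))  # True
```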
Does anybody know why this happens, and what alternatives there might be? For the Conv1D versions, I think the slowdown is mainly caused by the tensor copying needed to get the input into the right shape. For the Conv3D version, I assume the operation is simply not optimized for such a small kernel.
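The layout side of the Conv1D hypothesis can be inspected directly. In this CPU sketch (hypothetical `size = 16`), the `reshape` after the permute in `conv_1d` turns out to return a view rather than a copy, but that view is non-contiguous: consecutive depth samples are `size * size` elements apart. The assumption here is that the convolution backend then has to gather this strided data or make a dense copy itself; the snippet only shows the layouts, it does not measure where the time goes.

```python
import torch

size = 16  # hypothetical small volume edge; layout inspection works on CPU
x = torch.randn(1, 1, size, size, size)

# Flattening over the last axis: every row of `size` samples is contiguous
flat_last = x.reshape(-1, 1, size)
print(flat_last.is_contiguous(), flat_last.stride(-1))  # True 1

# Moving depth last, as conv_1d does: reshape still returns a view here...
flat_depth = x.permute(0, 3, 4, 1, 2).reshape(-1, 1, size)
print(flat_depth.data_ptr() == x.data_ptr())  # True
# ...but consecutive depth samples are size*size elements apart in memory
print(flat_depth.is_contiguous(), flat_depth.stride(-1))  # False 256

# Producing a dense layout (which a conv backend may need) copies the volume
dense = flat_depth.contiguous()
print(dense.data_ptr() == x.data_ptr())  # False
```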
For reference, when I remove the permutations from the Conv1D versions (so that they convolve along the last axis instead of the depth axis), I get the following result:
```
[---------------- Test of separable convolution ----------------]
      |  matmul  |  conv_3d  |  conv_1d  |  conv_1d_groups
1 threads: ------------------------------------------------------
256   |   256.3  |    724.6  |    152.3  |    173.0
512   |  3599.6  |   5702.9  |   1196.3  |   1340.6
1024  | 53334.2  |  45847.1  |   9542.4  |  10715.2

Times are in microseconds (us).
```
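As a possible alternative for the depth axis: because the filter has exactly two taps and is applied with stride 2, the windows do not overlap, so the filtering reduces to a weighted sum of the even and odd depth slices. Each slice consists of whole contiguous H×W planes, so no permute or kernel reshaping is needed. The sketch below uses plain strided slicing; `lowpass_depth` is a hypothetical helper name, the check against `F.conv3d` runs on a small CPU volume, and it has not been benchmarked against the numbers above.

```python
import torch
from torch.nn import functional as F

def lowpass_depth(x, kernel):
    # Hypothetical helper: two-tap, stride-2 filter along depth (dim 2).
    # Non-overlapping windows mean out[d] = k0 * x[2d] + k1 * x[2d + 1].
    return kernel[0] * x[:, :, 0::2] + kernel[1] * x[:, :, 1::2]

size = 16  # small CPU example
kernel = torch.tensor([0.7071, 0.7071])
x = torch.randn(1, 1, size, size, size)

reference = F.conv3d(x, kernel[None, None, :, None, None], stride=(2, 1, 1))
result = lowpass_depth(x, kernel)
print(result.shape)  # torch.Size([1, 1, 8, 16, 16])
print(torch.allclose(result, reference, atol=1e-6))  # True
```

The same helper covers the high-pass half of the wavelet step by passing a kernel with opposite signs, since only the two tap values change.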
I am asking mostly out of curiosity and to understand what is going on, since all of the implementations are fast enough for my use case.