Convolution multiplies each patch of the image by a kernel (patch * kernel), and I wanted to try different versions of it, for example (patch - kernel).abs(), or a weighted median filter, and so on.
So I made a version of convolution where those things can be implemented using unfolds and confirmed that it works the same as scipy.signal.convolve:
    import torch

    input = torch.randn(3, 32, 32)
    kernel = torch.randn(1, 5, 5)
    patched = input.unfold(1, 5, 1).unfold(2, 5, 1)  # patched.shape = (3, 28, 28, 5, 5)
    convolved = (patched * kernel).sum((-1, -2))  # convolved.shape = (3, 28, 28)
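For reference, here is a minimal sketch of that check together with one of the custom variants. It assumes torch.nn.functional.conv2d as the reference instead of scipy (my swap, to keep everything in PyTorch), and treats the three planes as a batch of single-channel images.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
input = torch.randn(3, 32, 32)
kernel = torch.randn(1, 5, 5)

# Sliding 5x5 windows as a view: shape (3, 28, 28, 5, 5)
patched = input.unfold(1, 5, 1).unfold(2, 5, 1)

# Standard "multiply and sum" version
convolved = (patched * kernel).sum((-1, -2))

# Reference: each plane becomes one single-channel image in a batch of 3
reference = F.conv2d(input.unsqueeze(1), kernel.unsqueeze(0)).squeeze(1)
matches = torch.allclose(convolved, reference, atol=1e-4)

# Custom variant: sum of absolute differences instead of products
abs_diff = (patched - kernel).abs().sum((-1, -2))
```

Only the expression between the unfold and the sum changes between the two versions, which is the flexibility I am after.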
And I made a convolutional layer:
    def conv2d_layer(input, kernel):
        # input: (batch_size, in_channels, H, W)
        # kernel: (out_channels, in_channels, K1, K2)
        out_channels = kernel.shape[0]
        ksize = kernel.shape[-2:]
        input = input.unsqueeze(-1).expand(*input.size(), out_channels).movedim(1, -1)
        # input.shape = (batch_size, H, W, out_channels, in_channels)
        patched = input.unfold(-4, ksize[0], 1).unfold(-4, ksize[1], 1)
        # patched.shape = (batch_size, 28, 28, out_channels, in_channels, K1, K2)
        # returned shape = (batch_size, out_channels, 28, 28)
        return (patched * kernel).sum((-1, -2, -3)).movedim(-1, 1)
I've tested it, and the output matches torch.nn.functional.conv2d. However, my function uses about 300 times more VRAM: for one layer from a medium-sized CNN, the PyTorch version uses 0.1 GB, while mine tries to allocate 32 GB and fails.
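As far as I understand, unsqueeze, expand, movedim and unfold all return views, so the first real allocation is the result of patched * kernel, which has one element per (batch, output pixel, out_channel, in_channel, kernel position). A rough sketch of that arithmetic is below; the layer sizes are hypothetical, not the ones from my network, and float32 is assumed.

```python
def broadcast_product_bytes(batch, in_ch, out_ch, h, w, k1, k2, bytes_per_elem=4):
    """Size of the materialized (patched * kernel) tensor, stride 1, no padding."""
    h_out, w_out = h - k1 + 1, w - k2 + 1
    return batch * h_out * w_out * out_ch * in_ch * k1 * k2 * bytes_per_elem


def output_bytes(batch, out_ch, h, w, k1, k2, bytes_per_elem=4):
    """Size of the final (batch, out_ch, h_out, w_out) result."""
    h_out, w_out = h - k1 + 1, w - k2 + 1
    return batch * out_ch * h_out * w_out * bytes_per_elem


# Hypothetical layer: batch 32, 64 -> 128 channels, 32x32 input, 3x3 kernel
intermediate = broadcast_product_bytes(32, 64, 128, 32, 32, 3, 3)
result = output_bytes(32, 128, 32, 32, 3, 3)
ratio = intermediate // result  # equals in_ch * k1 * k2

print(f"intermediate: {intermediate / 1e9:.2f} GB")
print(f"output:       {result / 1e6:.2f} MB")
print(f"ratio:        {ratio}x")
```

So the intermediate is in_channels * K1 * K2 times larger than the output, which is the same order of magnitude as the blow-up I am seeing.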
In terms of speed, it seems to be only slightly slower than the PyTorch version on GPU, but for some reason it is extremely slow on CPU (about 1000 times slower than PyTorch). That isn't a problem in itself because I use a GPU, but it might help diagnose where the issue is.
I also tried to squish all of that into one line, but that didn't seem to help:
    def conv2d_layer2(input, kernel):
        out_channels = kernel.shape[0]
        ksize = kernel.shape[-2:]
        return (input.unsqueeze(-1).expand(*input.size(), out_channels).movedim(1, -1).unfold(-4, ksize[0], 1).unfold(-4, ksize[1], 1) * kernel).sum((-1, -2, -3)).movedim(-1, 1)
Are there any ways I could make this require less memory? I know you can implement convolution with unfolds and matrix multiplication, but then I won't be able to do my custom versions of the convolution layer, like (patched - kernel).abs() and other things like that.
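One direction that might be worth trying is to process the output channels in chunks, so that only a slice of the big intermediate exists at any time while the patch operation stays fully custom. The sketch below is only an illustration under that assumption: chunked_patch_op, patch_op and chunk are hypothetical names, stride 1 and no padding are assumed, and it has not been profiled.

```python
import torch
import torch.nn.functional as F


def chunked_patch_op(input, kernel, patch_op, chunk=8):
    """input: (B, C, H, W); kernel: (O, C, K1, K2); patch_op: elementwise callable."""
    k1, k2 = kernel.shape[-2:]
    # View of all sliding windows, no copy: (B, C, H', W', K1, K2)
    patched = input.unfold(2, k1, 1).unfold(3, k2, 1)
    outs = []
    for start in range(0, kernel.shape[0], chunk):
        k = kernel[start:start + chunk]  # (o, C, K1, K2)
        # Broadcast (B, 1, C, H', W', K1, K2) with (o, C, 1, 1, K1, K2)
        combined = patch_op(patched.unsqueeze(1), k[:, :, None, None])
        # Reduce over in_channels and both kernel dims: (B, o, H', W')
        outs.append(combined.sum((2, -1, -2)))
    return torch.cat(outs, dim=1)


torch.manual_seed(0)
x = torch.randn(2, 3, 16, 16)
w = torch.randn(10, 3, 3, 3)

# With multiplication this should reproduce an ordinary conv2d
product = chunked_patch_op(x, w, torch.mul, chunk=4)
same_as_conv = torch.allclose(product, F.conv2d(x, w), atol=1e-4)

# Custom variant: sum of absolute differences
abs_diff = chunked_patch_op(x, w, lambda p, k: (p - k).abs(), chunk=4)
```

This only bounds the peak size of the intermediate in the forward pass. With autograd enabled, tensors saved for the backward pass can still add up across chunks, so something like torch.utils.checkpoint may be needed on top of it during training.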