I’m trying to write a funky 3D-convolution-like operation and I’m running into a performance issue. I’ve provided a minimal repro of what I’m doing:
```python
import torch

def foo():
    # Generate some data to test on
    volume_data = torch.rand(1, 2, 33, 33, 33, device="cuda", requires_grad=True)
    funky_kernel = torch.rand(1, 729, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, device="cuda", requires_grad=True)

    # Create the unfolded view of the volume
    volume_view = volume_data.unfold(dimension=2, size=2, step=1)
    volume_view = volume_view.unfold(dimension=3, size=2, step=1)
    volume_view = volume_view.unfold(dimension=4, size=2, step=1)
    volume_view = volume_view.view(1, 1, 2, 16, 2, 16, 2, 16, 2, 2, 2, 2)

    # The actual calculation
    output = funky_kernel * volume_view
    output = output.sum(dim=(-3, -2, -1))
    return output

print(foo().shape)
```
The issue is that the main output calculation (the multiply and the sum) is slow, roughly 5X slower than a native convolution. My hypothesis is that this is because the multiplication materializes a large (8X) intermediate tensor before the sum reduction runs. I would like to fuse this multiplication and summation into a single operation.
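To put a number on that intermediate, a quick check of the broadcast shapes (using the sizes from the repro above):

```python
import torch

# Broadcasting the kernel against the unfolded view gives the shape of
# the product that gets materialized before the reduction.
prod_shape = torch.broadcast_shapes(
    (1, 729, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2),   # funky_kernel
    (1, 1, 2, 16, 2, 16, 2, 16, 2, 2, 2, 2),  # volume_view
)
print(prod_shape.numel())       # 382205952 elements, ~1.4 GiB in fp32
print(prod_shape.numel() // 8)  # 47775744 elements survive the sum over the last 3 dims
```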
I’ve been trying to use one of the many `matmul`/`bmm` variations to do this, but I’m not too familiar with the daunting collection of initialism-named matrix functions and I haven’t been able to find one that fits my needs.
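For reference, here is what I understand these ops to expect (a toy sketch, not my actual shapes):

```python
import torch

a = torch.rand(10, 3, 4)
b = torch.rand(10, 4, 5)

# bmm is strictly batched 3-D: (b, n, m) @ (b, m, p) -> (b, n, p)
print(torch.bmm(a, b).shape)  # torch.Size([10, 3, 5])

# matmul does the same, but also broadcasts leading batch dimensions
print(torch.matmul(torch.rand(7, 1, 3, 4), torch.rand(10, 4, 5)).shape)
# torch.Size([7, 10, 3, 5])
```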
The promising idea seemed to be using `matmul`, like:

```python
output = torch.matmul(volume_view.view(-1, 1, 1, 8), funky_kernel.view(1, -1, 8, 1))
```
However, this doesn’t work: `volume_view.view(-1, 1, 1, 8)` fails with `view size is not compatible with input tensor's size and stride`. Using `reshape` instead would do a large memory copy, which is unacceptable for performance, and in practice it triggers an out-of-memory error with an 11.4 GiB allocation.
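The stride layout explains why `view` refuses: after the three `unfold` calls, the window dimensions are scattered in memory, so they can’t be merged into a single size-8 dimension without a copy. A quick check on a CPU copy of the same 33³ volume:

```python
import torch

volume_data = torch.rand(1, 2, 33, 33, 33)
w = volume_data.unfold(2, 2, 1).unfold(3, 2, 1).unfold(4, 2, 1)
print(w.shape)    # torch.Size([1, 2, 32, 32, 32, 2, 2, 2])
print(w.stride()) # (71874, 35937, 1089, 33, 1, 1089, 33, 1)
# The last three dims have strides (1089, 33, 1): the 8 window elements
# are non-adjacent, so view() cannot flatten them into one dimension.
```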
Is there another fast way I can do this?
To state it simply, what I need is a single fast operation that does a pointwise multiplication and a summation over the last three dimensions, and that can operate on a view without requiring it to be reshaped in an incompatible way.
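For concreteness, something with the semantics of the sketch below would be ideal, if it can run without materializing the full product. This replaces the multiply/sum inside `foo` above; the `expand` calls only create stride-0 views (no copy) so that `einsum` sees matching sizes. I haven’t verified whether `einsum` actually avoids the big intermediate here:

```python
# Equivalent to (funky_kernel * volume_view).sum(dim=(-3, -2, -1)),
# using the tensors from the repro above.
shape = torch.broadcast_shapes(funky_kernel.shape, volume_view.shape)
k = funky_kernel.expand(shape)  # stride-0 view, no copy
v = volume_view.expand(shape)   # stride-0 view, no copy
output = torch.einsum('abcdefghijkl,abcdefghijkl->abcdefghi', k, v)
print(output.shape)  # torch.Size([1, 729, 2, 16, 2, 16, 2, 16, 2])
```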