I'm trying to write a funky 3D convolution-like operation and am running into a performance issue. I've provided a minimal repro of what I'm doing below.
```python
import torch

def foo():
    # Generate some data to test on
    volume_data = torch.rand(1, 2, 33, 33, 33, device="cuda", requires_grad=True)
    funky_kernel = torch.rand(1, 729, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, device="cuda", requires_grad=True)

    # Create the unfolded view of the volume
    volume_view = volume_data.unfold(dimension=2, size=2, step=1)
    volume_view = volume_view.unfold(dimension=3, size=2, step=1)
    volume_view = volume_view.unfold(dimension=4, size=2, step=1)
    volume_view = volume_view.view(1, 1, 2, 16, 2, 16, 2, 16, 2, 2, 2, 2)

    # The actual calculation
    output = funky_kernel * volume_view
    output = output.sum(dim=(-3, -2, -1))
    return output

print(foo().shape)
```
The issue is that the main output calculation (the multiply and sum) is slow, roughly 5× slower than a native convolution. My hypothesis is that the multiplication materializes a large (8×) intermediate tensor before the sum reduction runs. I would like to fuse this multiplication and summation into a single operation.
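One candidate I looked at is `torch.einsum`, which at least expresses the multiply-and-sum as a single call (ellipsis broadcasting over the leading dims, explicit labels for the three window dims), though I'm not sure whether it avoids materializing the intermediate internally. A scaled-down CPU sketch (9³ volume and 27 kernels standing in for 33³ and 729):

```python
import torch

# Scaled-down CPU stand-ins for the CUDA tensors in the repro
volume_data = torch.rand(1, 2, 9, 9, 9)
funky_kernel = torch.rand(1, 27, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2)

v = volume_data.unfold(2, 2, 1).unfold(3, 2, 1).unfold(4, 2, 1)
v = v.view(1, 1, 2, 4, 2, 4, 2, 4, 2, 2, 2, 2)

# '...ijk,...ijk->...' multiplies elementwise over the last three dims and
# sums them out; the leading dims broadcast under the ellipsis.
out = torch.einsum('...ijk,...ijk->...', funky_kernel, v)

# Matches the original multiply-then-sum formulation
ref = (funky_kernel * v).sum(dim=(-3, -2, -1))
print(out.shape, torch.allclose(out, ref, atol=1e-5))
```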
I've been trying to use one of the many `bmm` variations to do this, but I'm not familiar enough with the daunting collection of initialism-named matrix functions, and I haven't been able to find one that fits my needs.
The most promising idea seemed to be:

```python
output = torch.matmul(volume_view.view(-1, 1, 1, 8), funky_kernel.view(1, -1, 8, 1))
```

However, this doesn't work: `volume_view.view(-1, 1, 1, 8)` fails with `view size is not compatible with input tensor's size and stride`. Using `reshape` instead would do a large memory copy, which is unacceptable for performance and in practice triggers an out-of-memory error with an 11.4 GiB allocation.
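For context, a scaled-down CPU illustration of why the `view` fails: `unfold` leaves the three window dims with the original volume's strides, so they can be split further but not merged into one dim without a copy.

```python
import torch

x = torch.rand(1, 2, 9, 9, 9)  # scaled-down CPU stand-in for the volume
v = x.unfold(2, 2, 1).unfold(3, 2, 1).unfold(4, 2, 1)

print(v.shape)            # the three window dims of size 2 sit at the end
print(v.stride()[-3:])    # (81, 9, 1) -- not the (4, 2, 1) a merge would need
print(v.is_contiguous())  # False

# Merging the last three dims into one of size 8 requires contiguous
# strides among them, so view() raises; reshape() would silently copy.
try:
    v.view(-1, 1, 1, 8)
except RuntimeError as e:
    print("view failed:", e)
```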
Is there another fast way I can do this? To state my need simply: I want a single fast operation that does pointwise multiplication and summation over the last three dimensions, and that can operate on a view without requiring an incompatible reshape.
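To pin down the exact semantics, here's a scaled-down CPU reference I used: unrolling the three size-2 window dims into eight accumulated terms reproduces the multiply-and-sum without ever allocating the 8× intermediate (though as eight separate kernel launches rather than one fused op, so it's a workaround, not the answer I'm after).

```python
import torch

# Scaled-down CPU stand-ins for the CUDA tensors in the repro
volume_data = torch.rand(1, 2, 9, 9, 9)
funky_kernel = torch.rand(1, 27, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2)

v = volume_data.unfold(2, 2, 1).unfold(3, 2, 1).unfold(4, 2, 1)
v = v.view(1, 1, 2, 4, 2, 4, 2, 4, 2, 2, 2, 2)

# Accumulate the eight window positions one at a time; each term only
# materializes the broadcast output shape, never the 8x intermediate.
out = torch.zeros(1, 27, 2, 4, 2, 4, 2, 4, 2)
for i in range(2):
    for j in range(2):
        for k in range(2):
            out += funky_kernel[..., i, j, k] * v[..., i, j, k]

# Matches the original multiply-then-sum formulation
ref = (funky_kernel * v).sum(dim=(-3, -2, -1))
print(torch.allclose(out, ref, atol=1e-5))
```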