Fast batched multiply-and-sum over the last 3 dimensions of an arbitrary view, without reshape

I’m trying to write a funky 3D-convolution-like operation and am running into a performance issue. Below is a minimal repro of what I’m doing.

import torch

def foo():
  # Generate some data to test on
  volume_data = torch.rand(1, 2, 33, 33, 33, device="cuda", requires_grad=True)
  funky_kernel = torch.rand(1, 729, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2, device="cuda", requires_grad=True)

  # Create the unfolded (sliding 2x2x2 window) view of the volume; unfold does not copy
  volume_view = volume_data.unfold(dimension=2, size=2, step=1)  # (1, 2, 32, 33, 33, 2)
  volume_view = volume_view.unfold(dimension=3, size=2, step=1)  # (1, 2, 32, 32, 33, 2, 2)
  volume_view = volume_view.unfold(dimension=4, size=2, step=1)  # (1, 2, 32, 32, 32, 2, 2, 2)
  volume_view = volume_view.view(1, 1, 2, 16, 2, 16, 2, 16, 2, 2, 2, 2)

  # The actual calculation: broadcast multiply, then reduce over the 2x2x2 window
  output = funky_kernel * volume_view  # materializes (1, 729, 2, 16, 2, 16, 2, 16, 2, 2, 2, 2)
  output = output.sum(dim=(-3, -2, -1))  # (1, 729, 2, 16, 2, 16, 2, 16, 2)
  return output

print(foo().shape)

The issue is that the main output calculation (the multiply and sum) is about 5x slower than a native convolution. My hypothesis is that this is because the multiplication materializes a large intermediate tensor, 8x the size of the output, before the sum reduction runs. I would like to fuse the multiplication and the summation into a single operation.
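The 8x figure can be checked without allocating anything on the GPU. A minimal sketch, assuming PyTorch 1.8 or later for torch.broadcast_shapes and Python 3.8 or later for math.prod: compute the broadcast shape of funky_kernel * volume_view from the shapes in the repro, then compare its element count with the summed output. The GiB figures assume float32 (4 bytes per element).

```python
import math
import torch

# Shapes copied from the repro
kernel_shape = (1, 729, 1, 1, 2, 1, 2, 1, 2, 2, 2, 2)
view_shape = (1, 1, 2, 16, 2, 16, 2, 16, 2, 2, 2, 2)

# Shape of funky_kernel * volume_view after broadcasting (no tensors are allocated)
product_shape = torch.broadcast_shapes(kernel_shape, view_shape)
output_shape = product_shape[:-3]  # what remains after .sum(dim=(-3, -2, -1))

product_elems = math.prod(product_shape)
output_elems = math.prod(output_shape)
ratio = product_elems // output_elems

print(tuple(product_shape), product_elems * 4 / 2**30, "GiB as float32")
print(tuple(output_shape), output_elems * 4 / 2**30, "GiB as float32")
print("intermediate / output size:", ratio)
```

The ratio is simply the product of the three reduced window sizes, 2 x 2 x 2.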

I’ve been trying to use one of the many matmul/bmm variations to do this, but I’m not very familiar with the daunting collection of matrix functions with abbreviated names, and I haven’t been able to find one that fits my needs.
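For orientation, here is a short sketch of how the shapes behave for the three most common functions, based on the documented semantics of torch.mm, torch.bmm and torch.matmul: mm accepts only 2-D matrices, bmm accepts 3-D batches whose batch sizes must match, and matmul (the @ operator) broadcasts all leading batch dimensions. A batch of row-wise dot products therefore needs both operands to share the same batch shape, which the unsqueeze calls arrange. All tensors are small random stand-ins, not the repro's data.

```python
import torch

a2, b2 = torch.rand(3, 8), torch.rand(8, 1)
a3, b3 = torch.rand(5, 3, 8), torch.rand(5, 8, 1)

mm_out = torch.mm(a2, b2)    # 2-D only: (3, 8) @ (8, 1) -> (3, 1)
bmm_out = torch.bmm(a3, b3)  # 3-D, equal batch sizes: (5, 3, 8) @ (5, 8, 1) -> (5, 3, 1)

# matmul broadcasts the leading (batch) dims: (4, 1) with (1, 5) -> (4, 5)
matmul_out = torch.matmul(torch.rand(4, 1, 3, 8), torch.rand(1, 5, 8, 1))  # (4, 5, 3, 1)

# Batched dot product of matching rows: (N, 1, 8) @ (N, 8, 1) -> (N, 1, 1) -> (N,)
x, y = torch.rand(6, 8), torch.rand(6, 8)
dots = (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1)
```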

The most promising idea seemed to be using matmul like this:

output = torch.matmul(volume_view.view(-1, 1, 1, 8), funky_kernel.view(1, -1, 8, 1))

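However, this doesn’t work: volume_view.view(-1, 1, 1, 8) fails with "view size is not compatible with input tensor's size and stride", because the three window dimensions created by unfold are not laid out contiguously in memory and cannot be merged into a single dimension without a copy. Using reshape instead would make a large memory copy, which is unacceptable for performance and causes an out-of-memory error on an 11.4 GiB allocation.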
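The view failure follows from the strides. unfold does not copy, so each new window dimension keeps the stride of the spatial axis it was taken from: 33 * 33 = 1089, 33 and 1 elements for a 33^3 volume. Flattening three size-2 dimensions into one of size 8 would need strides of the form (4s, 2s, s), so view has to refuse, and reshape falls back to a copy. A small CPU check using the repro's unfold steps (without autograd):

```python
import torch

# Same unfold steps as the repro, on CPU
volume_data = torch.rand(1, 2, 33, 33, 33)
volume_view = volume_data.unfold(2, 2, 1).unfold(3, 2, 1).unfold(4, 2, 1)

# The 2x2x2 window dims keep the strides of the D, H and W axes
window_strides = volume_view.stride()[-3:]
print(tuple(volume_view.shape), window_strides)

# Merging them into one dim of size 8 is impossible for a view
try:
    volume_view.view(-1, 1, 1, 8)
    view_failed = False
except RuntimeError as err:
    view_failed = True
    print("view failed:", err)

# reshape succeeds only by copying into new storage
flat = volume_view.reshape(-1, 8)
reshape_copied = flat.data_ptr() != volume_data.data_ptr()
print("reshape copied:", reshape_copied)
```

Separately, even if the view succeeded, the shapes in the matmul call above, (65536, 1, 1, 8) and (1, 5832, 8, 1), broadcast their batch dimensions to (65536, 5832), so matmul would compute every window/kernel-slice pairing rather than only the matching pairs.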

Is there another fast way I can do this?

To state it simply, I need a single fast operation that does pointwise multiplication followed by summation over the last three dimensions, and that can operate on a view without having to reshape it in a way its strides don’t allow.
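One way to avoid the 8x intermediate without writing a custom kernel is to loop over the 2 x 2 x 2 window offsets and accumulate one output-sized product per offset. This is a swapped-in technique and a sketch, not a known fast path. Each term multiplies two strided views, so nothing is reshaped, and the extra memory is roughly one output-sized tensor instead of the full broadcast product. The helper name mul_sum_last3 is hypothetical, and the sketch assumes both operands have identical sizes in their last three dimensions, as in the repro.

```python
import itertools
import torch

def mul_sum_last3(a, b):
    # Hypothetical helper: computes (a * b).sum(dim=(-3, -2, -1)) without
    # materializing the broadcast product. Assumes a and b have the same
    # sizes in their last three dims; leading dims broadcast as usual.
    assert a.shape[-3:] == b.shape[-3:], "window dims must match"
    out = None
    for i, j, k in itertools.product(*(range(n) for n in a.shape[-3:])):
        # Plain indexing returns strided views, so nothing is copied here
        term_a = a[..., i, j, k]
        term_b = b[..., i, j, k]
        if out is None:
            out = term_a * term_b  # first output-sized tensor
        else:
            out.addcmul_(term_a, term_b)  # in place: out += term_a * term_b
    return out
```

With the repro's tensors the call would be mul_sum_last3(funky_kernel, volume_view). It launches eight separate kernels, so whether it is actually faster on the GPU than the single multiply-and-sum would have to be measured.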

einsum seems like it should be the answer, but I’m finding that it’s a lot slower (~25X!) than the basic multiply and sum.

This is the code I’m using for that.

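def einsum_mul_sum(a, b):  # wrapper name is illustrative
  # Explicitly broadcast both operands, then contract the last three dims
  (a, b) = torch.broadcast_tensors(a, b)
  return torch.einsum("...ijk, ...ijk -> ...", a, b)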
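In current PyTorch releases, torch.einsum broadcasts the dimensions covered by the ellipsis on its own, so the explicit torch.broadcast_tensors call should not be necessary. This sketch does not measure whether dropping it changes the ~25X slowdown; it only checks, on small CPU stand-ins with the same broadcasting pattern, that the equation matches the multiply-and-sum reference.

```python
import torch

torch.manual_seed(0)
# Leading dims broadcast against each other; the last three are the 2x2x2 window
a = torch.rand(1, 3, 1, 4, 2, 2, 2)
b = torch.rand(1, 1, 5, 4, 2, 2, 2)

reference = (a * b).sum(dim=(-3, -2, -1))

# No torch.broadcast_tensors: einsum broadcasts the "..." dims itself
result = torch.einsum("...ijk,...ijk->...", a, b)

print(result.shape, torch.allclose(result, reference, atol=1e-6))
```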

I’ve filed an issue about the einsum slowdown at https://github.com/pytorch/pytorch/issues/32591