Matmul broadcasting makes copies?

Based on the docs, matmul will broadcast to make inputs compatible.
From this line, it seems that the following script will expand the second tensor and make 1000 copies of it (due to contiguous()).

import torch

x = torch.rand(1000, 10, 3, 4)
y = torch.rand(10, 4, 5)
z = torch.matmul(x, y)

If this is correct, we should change the docs, since they suggest that broadcasting here does not make copies.
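A minimal sketch of the mechanism described above, assuming matmul folds the broadcast batch dimensions into a single bmm batch axis (the exact internal code path may differ between PyTorch versions). Expanding y is free because the new axis gets stride 0, but flattening the expanded batch axes cannot be done as a view, so reshape() has to allocate a full copy:

```python
import torch

x = torch.rand(1000, 10, 3, 4)
y = torch.rand(10, 4, 5)

# Broadcasting y to x's batch shape is a view: the new leading axis
# gets stride 0, so no memory is allocated yet.
y_exp = y.expand(1000, 10, 4, 5)
print(y_exp.stride())                     # (0, 20, 5, 1)
print(y_exp.data_ptr() == y.data_ptr())   # True: same storage

# Folding (1000, 10) into one batch axis for bmm cannot be a view,
# because strides 0 and 20 cannot be merged into a single stride.
# reshape() therefore falls back to a contiguous copy.
y_flat = y_exp.reshape(1000 * 10, 4, 5)
print(y_flat.data_ptr() == y.data_ptr())  # False: new storage
print(y_flat.numel() // y.numel())        # 1000: y is copied 1000 times

# The folded bmm gives the same result as matmul.
z_bmm = torch.bmm(x.reshape(1000 * 10, 3, 4), y_flat).reshape(1000, 10, 3, 5)
print(torch.allclose(z_bmm, torch.matmul(x, y)))  # True
```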


Hi,

that is a great observation and good research to pinpoint the exact line that is problematic!

I think this has two parts:

  • There does indeed seem to be something going on with matmul that shouldn't be. With the snippet below, matmul takes 4.5 ms, while even the naive mul + sum takes 0.6 ms:
     import torch  # %timeit below requires IPython/Jupyter
     x = torch.rand(1000, 10, 3, 4)
     y = torch.rand(10, 4, 5)
     y2 = y.unsqueeze(0)
     x_ = x.unsqueeze(-1)
     y_ = y2.unsqueeze(2)
     %timeit z = torch.matmul(x, y)
     %timeit z2 = (x_*y_).sum(-2)
  • matmul's strategy is to reshape the inputs into something that can be fed to at::bmm. In your example, that is not possible without a copy: after broadcasting, the first axis of y has stride 0, and that cannot be merged with the stride of the second axis into a single batch axis. So a proper fix probably means reworking matmul rather fundamentally. (Probably a fun project!)
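As a sanity check (a sketch, not a benchmark), the two expressions timed above compute the same tensor. Note that the mul + sum version is not copy-free either: it materializes a (1000, 10, 3, 4, 5) intermediate before reducing over the shared axis.

```python
import torch

x = torch.rand(1000, 10, 3, 4)
y = torch.rand(10, 4, 5)

z = torch.matmul(x, y)             # shape (1000, 10, 3, 5)

# Naive version: line up the contraction axis and broadcast-multiply.
x_ = x.unsqueeze(-1)               # (1000, 10, 3, 4, 1)
y_ = y.unsqueeze(0).unsqueeze(2)   # (1, 10, 1, 4, 5)
prod = x_ * y_                     # (1000, 10, 3, 4, 5) temporary
z2 = prod.sum(-2)                  # reduce over the shared size-4 axis

print(z.shape, z2.shape)           # both torch.Size([1000, 10, 3, 5])
print(prod.numel())                # 600000 elements in the temporary
print(torch.allclose(z, z2, atol=1e-6))  # True
```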

If you wanted to, you could file a bug. I do believe there is a bug in the code, not just the docs.

Best regards

Thomas


The CPU kernel is slow because we use a for loop over the batch dimension. I think the main reason for making copies and using bmm is that otherwise we don't have a readily available kernel for the GPU op. Right now, CUDA provides cublasSgemmBatched, cublasDgemmBatched, and their strided versions for batched mm.

I just posted a feature request for cublas<T>gemmStridedBatched here.

cublasGemmBatchedEx (https://docs.nvidia.com/cuda/cublas/index.html#cublas-GemmBatchedEx) should be able to do what you want: it accepts arrays of pointers, so the same B pointer can simply repeat throughout the array. However, cublasGemmBatchedEx is not used in the PyTorch backend, because preparing those pointer arrays is somewhat painful and is avoided unless really necessary.

PyTorch actually used something similar before CUDA 8.0, and that approach would still work here. But as you said, allocating lots of pointers is not very efficient: for the toy example above, it allocates 30K pointers and is actually slower than making copies + gemmStridedBatched. It would be really nice if we could do gemmStridedBatched without copying.
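To make the pointer-array idea concrete, here is a pure-Python analogue (a sketch for illustration, not cuBLAS and not how PyTorch implements it): each of the 1000 × 10 small GEMMs reads x[i, j] and y[j] directly, so y is reused rather than copied. A pointer-array API would need one A, one B, and one C pointer per GEMM, which matches the 30K figure for the toy example. The Python loop itself is slow; only the memory-access pattern is the point.

```python
import torch

x = torch.rand(1000, 10, 3, 4)
y = torch.rand(10, 4, 5)
out = torch.empty(1000, 10, 3, 5)

num_gemms = 0
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        # x[i, j], y[j] and out[i, j] are views: no input data is copied.
        # y[j] is the same matrix for every i, like a repeated B pointer.
        torch.mm(x[i, j], y[j], out=out[i, j])
        num_gemms += 1

print(num_gemms)        # 10000
print(3 * num_gemms)    # 30000 pointers (A, B, C) for a pointer-array API
print(torch.allclose(out, torch.matmul(x, y)))  # True
```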

Hi @tom,

is naive mul + sum still faster than matmul in PyTorch 1.0? I tried your code and got the opposite.

Thank you,

Mikaël

Is there still no memory- and time-efficient way to do broadcasted matrix multiplication? It seems like this would be essential for something like a transformer network, since otherwise dot-product attention would require copying both the keys and the queries. Is there a special trick for dot-product attention in particular?
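One point worth checking for attention specifically: in the usual multi-head layout, queries and keys share the same (batch, heads) dimensions, so no batch axis is broadcast at all. A small sketch, assuming that layout (the sizes B, H, L, D are arbitrary illustration values), showing that folding the batch axes for bmm is a view for both operands, even for the transposed keys:

```python
import torch

B, H, L, D = 2, 4, 8, 16           # batch, heads, sequence length, head dim
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
kt = k.transpose(-2, -1)           # (B, H, D, L), a strided view of k

# Matching batch dims: merging (B, H) into one axis is a view for both.
q_flat = q.reshape(B * H, L, D)
kt_flat = kt.reshape(B * H, D, L)
print(q_flat.data_ptr() == q.data_ptr())   # True
print(kt_flat.data_ptr() == k.data_ptr())  # True

scores = torch.matmul(q, kt) / D ** 0.5    # (B, H, L, L) attention logits
ref = torch.bmm(q_flat, kt_flat).reshape(B, H, L, L) / D ** 0.5
print(torch.allclose(scores, ref, atol=1e-5))  # True
```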


I was trying to figure out why my benchmarking trace was full of copy_() calls. It turns out they came from this matmul broadcasting issue. Now I use F.linear instead, which is a bit awkward since everything has to be transposed, but it is much faster, and all the copying is gone.

To be more precise: you need to structure your code so that the broadcast matrix multiplies from the right (batched @ shared), not from the left (shared @ batched). Here is a simple example:

import torch

BS, DIM = 100, 100
A = torch.randn(DIM, DIM)
B = torch.randn(BS, DIM, DIM)

%timeit A @ B
6.11 ms ± 117 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit B @ A
4.02 ms ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
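The rewrite relies on the identity (A B)ᵀ = Bᵀ Aᵀ: a product with the shared matrix on the left can be turned into one with the shared matrix on the right. F.linear(input, weight) computes input @ weight.T, which is why using it means transposing everything. A minimal sketch of the equivalence (correctness only; whether the transposes themselves trigger copies depends on your data layout and PyTorch version, so ideally the data is stored in the transposed layout from the start):

```python
import torch
import torch.nn.functional as F

BS, DIM = 100, 100
A = torch.randn(DIM, DIM)
B = torch.randn(BS, DIM, DIM)

left = A @ B                                  # shared matrix on the left

# (A @ B_i)^T = B_i^T @ A^T, so the shared matrix can go on the right.
Bt = B.transpose(-2, -1)
right = (Bt @ A.T).transpose(-2, -1)

# F.linear(x, W) computes x @ W.T, so it absorbs the transpose of A.
via_linear = F.linear(Bt, A).transpose(-2, -1)

print(torch.allclose(left, right, atol=1e-3))       # True
print(torch.allclose(left, via_linear, atol=1e-3))  # True
```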