Performance of torch.bmm with different CUDA kernels

Good evening,

When using torch.bmm() to multiply many (>10k) small 3x3 matrices, we hit a performance bottleneck apparently due to cuBLAS heuristics when choosing which kernel to call. For example, the colab notebook below shows that for 2^15 matrices the call takes 2s but only 0.5s for 2^16 matrices.

What’s the easiest way to fix this, keeping in mind that we’d like to keep differentiability via autograd?

Minimal Example Colaboratory:

StackOverflow discussion on batched GEMMs:


Quite likely, you will need to roll your own in some form. If neither bmm nor spelling out the contraction (for 10k 3x3 matrices that might be an option, probably not for very large ones) works for you, you might allocate the result array and feed batches through bmm_out.
Another alternative could be to see if CUTLASS works for you. The backward for matrix multiplication should be simple to implement (it is just a couple of matrix multiplications itself).
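
To make the two ideas above concrete, here is a rough sketch (not a tuned implementation) that combines them: the forward fills a preallocated result chunk by chunk through the out= argument of torch.bmm, and the backward is written by hand as two more batched matmuls. The names ChunkedBmm, _chunked_bmm and chunk_size are made up for illustration, and the torch.bmm call inside the helper is only a stand-in for whatever kernel (e.g. a CUTLASS binding) you end up using. Since out= does not work with autograd, the helper detaches its inputs, which also means double backward is not supported.

```python
import torch


def _chunked_bmm(a, b, chunk_size):
    # Hypothetical helper: out= is not differentiable, so work on detached tensors.
    a, b = a.detach(), b.detach()
    out = a.new_empty(a.shape[0], a.shape[1], b.shape[2])
    for start in range(0, a.shape[0], chunk_size):
        stop = min(start + chunk_size, a.shape[0])
        # Write each chunk's product directly into the matching slice of the result.
        torch.bmm(a[start:stop], b[start:stop], out=out[start:stop])
    return out


class ChunkedBmm(torch.autograd.Function):
    """Batched matmul computed chunk by chunk, with a hand-written backward."""

    @staticmethod
    def forward(ctx, a, b, chunk_size):
        # a: (B, n, k), b: (B, k, m) -> result: (B, n, m)
        ctx.save_for_backward(a, b)
        ctx.chunk_size = chunk_size
        return _chunked_bmm(a, b, chunk_size)

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        cs = ctx.chunk_size
        # For C = A @ B: dL/dA = dL/dC @ B^T and dL/dB = A^T @ dL/dC.
        grad_a = _chunked_bmm(grad_out, b.transpose(1, 2), cs)
        grad_b = _chunked_bmm(a.transpose(1, 2), grad_out, cs)
        # chunk_size is not a tensor, so it gets no gradient.
        return grad_a, grad_b, None
```

You would call it as ChunkedBmm.apply(a, b, chunk_size). Which chunk size avoids the slow kernel is something you would have to find by benchmarking on your GPU.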

Best regards


Thanks for your answer!

By spelling out the contraction, do you mean writing the operation as an einsum?

Using another library and writing the backward seems like a somewhat pain-free solution too, but I was wondering if there would be a way to indicate to cuBLAS which kernel to call. In the end, I suspect we might have to go that route.


No, einsum will itself use bmm. I was thinking of materializing the elementwise product and calling .sum on it.
(I do have a branch somewhere that uses TensorIterators for einsum instead. It’s so terrible on CPU (no AVX) that I didn’t look at GPU, but if you want to benchmark it on GPU, I can push it.)
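
For reference, a minimal sketch of what I mean by materializing the product (the function name bmm_elementwise is just for illustration). It broadcasts the two operands against each other and sums over the shared dimension, so it never goes through bmm and autograd works as usual. The intermediate has shape (B, n, k, m), which is fine for 3x3 matrices but gets expensive quickly for larger ones.

```python
import torch


def bmm_elementwise(a, b):
    # a: (B, n, k), b: (B, k, m)
    # a.unsqueeze(-1) -> (B, n, k, 1); b.unsqueeze(-3) -> (B, 1, k, m)
    # The broadcast product has shape (B, n, k, m); summing over k gives (B, n, m).
    return (a.unsqueeze(-1) * b.unsqueeze(-3)).sum(dim=-2)
```

Whether this is actually faster than bmm for your batch sizes is something to measure rather than assume.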

Best regards


For an example of using CUTLASS for batched matrix multiply, look here. It’s hardcoded to the half data type, so some changes are required.
cuBLAS allows specifying algorithms, but PyTorch uses “default” as the algo, in the hope that it will be optimal. Apparently it’s not in your case.
There’s also a new “matmul” interface in cuBLAS (not yet integrated into PyTorch) that gives somewhat finer control over the selected algorithms. It might also be useful, but integrating it into PyTorch is some work.
