When using torch.bmm() to multiply many (>10k) small 3x3 matrices, we hit a performance bottleneck apparently due to cuBLAS heuristics when choosing which kernel to call. For example, the colab notebook below shows that for 2^15 matrices the call takes 2s but only 0.5s for 2^16 matrices.
What’s the easiest way to fix this, keeping in mind that we’d like to preserve differentiability via autograd?
Quite likely you will need to roll your own in some form. If neither bmm nor spelling out the contraction works for you (for 10k 3x3 matrices that might be an option, probably not for very large ones), you might allocate the result array and feed batches through bmm_out.
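A minimal sketch of the "allocate the result and feed batches" idea, assuming that bmm_out refers to the `out=` variant of `torch.bmm`. The function name `chunked_bmm` and the chunk size are hypothetical and would need tuning against the problematic batch sizes. Note that `out=` variants do not support autograd, so this only covers the forward pass and would have to be paired with a hand-written backward.

```python
import torch

def chunked_bmm(a: torch.Tensor, b: torch.Tensor, chunk_size: int = 4096) -> torch.Tensor:
    """Batched matmul computed in chunks along the batch dimension.

    a: (B, n, k), b: (B, k, m) -> (B, n, m)
    chunk_size is a hypothetical tuning knob, not a recommended value.
    """
    batch, n, _ = a.shape
    m = b.shape[-1]
    # Preallocate the full result once.
    out = a.new_empty(batch, n, m)
    # out= does not record autograd history, so run without grad tracking.
    with torch.no_grad():
        for start in range(0, batch, chunk_size):
            end = min(start + chunk_size, batch)
            # Slices along dim 0 are contiguous views, so bmm writes in place.
            torch.bmm(a[start:end], b[start:end], out=out[start:end])
    return out

a = torch.randn(10000, 3, 3)
b = torch.randn(10000, 3, 3)
result = chunked_bmm(a, b, chunk_size=4096)
```

Whether this sidesteps the slow kernel depends on which batch sizes trigger the heuristic, so it is worth timing a few chunk sizes on the target GPU.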
Another alternative could be to see if CUTLASS works for you. The backward of a matrix multiplication should be simple to implement, since it is just a couple of matrix multiplications itself.
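To illustrate how small the backward is, here is a sketch of a custom `torch.autograd.Function` wrapping a batched matmul. The class name `BatchedMatmul` is hypothetical, and `torch.bmm` stands in for whatever external kernel (CUTLASS or otherwise) would actually be called. For C = A B, the gradients are dA = dC Bᵀ and dB = Aᵀ dC.

```python
import torch

class BatchedMatmul(torch.autograd.Function):
    """Batched matmul with a hand-written backward (illustrative sketch)."""

    @staticmethod
    def forward(ctx, a, b):
        # Save the inputs; both are needed to compute the gradients.
        ctx.save_for_backward(a, b)
        # Placeholder: a custom kernel call would replace torch.bmm here.
        return torch.bmm(a, b)

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # dL/dA = dL/dC @ B^T, dL/dB = A^T @ dL/dC
        grad_a = torch.bmm(grad_out, b.transpose(1, 2))
        grad_b = torch.bmm(a.transpose(1, 2), grad_out)
        return grad_a, grad_b

# Double precision so the result can be checked numerically with gradcheck.
a = torch.randn(8, 3, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(8, 3, 3, dtype=torch.double, requires_grad=True)
out = BatchedMatmul.apply(a, b)
out.sum().backward()
```

In this sketch the backward still goes through `torch.bmm`, so in practice those two calls would also need to be routed to the replacement kernel.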
By spelling out the contraction, do you mean writing the operation as an einsum?
Using another library and writing the backward seems like a somewhat pain-free solution too, but I was wondering if there would be a way to indicate to cuBLAS which kernel to call. In the end, I suspect we might have to go that route.
No, einsum will itself use bmm. I was thinking of materializing the elementwise product and calling .sum on it.
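A sketch of what materializing the elementwise product and summing could look like, using only broadcasting and `.sum` so that cuBLAS is never called and autograd works as usual. The helper name `bmm_elementwise` is hypothetical. The intermediate tensor has shape (B, n, k, m), which is cheap for 3x3 matrices but grows quickly for larger ones.

```python
import torch

def bmm_elementwise(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Batched matmul via broadcasting: a (B, n, k), b (B, k, m) -> (B, n, m)."""
    # a.unsqueeze(-1): (B, n, k, 1); b.unsqueeze(-3): (B, 1, k, m)
    # Their product broadcasts to (B, n, k, m); summing over k contracts it.
    return (a.unsqueeze(-1) * b.unsqueeze(-3)).sum(dim=-2)

a = torch.randn(1024, 3, 3, requires_grad=True)
b = torch.randn(1024, 3, 3, requires_grad=True)
out = bmm_elementwise(a, b)
# Gradients flow through the ordinary multiply and sum ops.
out.sum().backward()
```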
(I do have a branch somewhere that uses TensorIterators for einsum instead. It’s so terrible on CPU (no AVX) that I didn’t look at GPU, but if you want to benchmark it on GPU, I can push it.)