Matrix multiplication implementation in PyTorch

I read about torch.autocast and how FP16 matrix multiplication is faster than FP32 on CUDA. However, on my Mac M1 (Apple Silicon), a 100x100 matrix multiplication takes 50 times longer in FP16 than in FP32. Reducing the matrix size decreases the computation time, but FP32 still remains faster.
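
For concreteness, here is a minimal timing sketch of the kind of comparison I am describing (the size and iteration count are arbitrary, not a careful benchmark):

```python
import time
import torch

def bench(dtype, n=100, iters=1000):
    a = torch.randn(n, n).to(dtype)
    b = torch.randn(n, n).to(dtype)
    for _ in range(10):  # warm-up
        torch.mm(a, b)
    start = time.perf_counter()
    for _ in range(iters):
        torch.mm(a, b)
    return time.perf_counter() - start

for dtype in (torch.float32, torch.float16):
    print(dtype, f"{bench(dtype):.4f}s")
```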

I’m trying to understand how torch.mm is implemented, but I couldn’t find the actual implementation code. Could you guide me on where to find it?

AFAIK, you need a GPU with tensor cores to benefit from faster float16 operations.

Yes, you are correct that FP16 is faster than FP32 on GPUs. However, I need to understand why the opposite is true on CPUs. FP16 occupies half the memory of FP32, so when the same operation is applied to half as much data, I would expect it to run faster. Why is this not the case on CPUs?

Hey!

Long story short, this is where most of the mm calls end up on the MPS device: pytorch/aten/src/ATen/native/mps/operations/LinearAlgebra.mm at 2ed4d65af0a1993c0df7b081f4088d0f3614283e · pytorch/pytorch · GitHub
And for CPU it ends up in pytorch/aten/src/ATen/native/CPUBlas.cpp at main · pytorch/pytorch · GitHub
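
If you just want to see which ATen operator your call dispatches to before it reaches those files, the profiler will list the op names; a quick sketch (not specific to any device):

```python
import torch
from torch.profiler import profile, ProfilerActivity

a = torch.randn(100, 100, dtype=torch.float16)
b = torch.randn(100, 100, dtype=torch.float16)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    torch.mm(a, b)

# Lists the dispatched ATen ops (e.g. aten::mm) and their CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total"))
```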

As you can see, in both cases we just end up calling into Apple-provided libraries (either an MPS kernel or an Accelerate gemm kernel), so the slowdown most likely comes from within those libraries.
In some cases (like the Metal gemm you can find there) we write our own kernels when they behave better than the vendor-provided ones, but we try to avoid that unless it is really necessary.
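
You can check which BLAS library your CPU build is linked against from the compile-time configuration dump:

```python
import torch

# Prints the compile-time configuration; the BLAS backend used by
# CPUBlas.cpp (e.g. Accelerate, MKL, OpenBLAS) should show up here.
print(torch.__config__.show())
```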

In your particular case, it is most likely that the library ends up doing a lot of expensive upcasting for no good reason and is therefore slower :confused:
That being said, there can also be packaging issues where a build ends up linked against Accelerate, MKL, or neither, so if you see significant differences in an easy-to-reproduce environment, please submit an issue so we can track it and try to work around it!
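
One way to test the upcasting theory yourself is to do the cast by hand and compare; a rough sketch (sizes and iteration count are arbitrary):

```python
import time
import torch

a = torch.randn(100, 100, dtype=torch.float16)
b = torch.randn(100, 100, dtype=torch.float16)

def timeit(fn, iters=1000):
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return time.perf_counter() - start

native = timeit(lambda: torch.mm(a, b))
# Explicit upcast to fp32, multiply, and downcast back to fp16.
manual = timeit(lambda: torch.mm(a.float(), b.float()).half())
print(f"native fp16: {native:.4f}s, manual upcast: {manual:.4f}s")
```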


I’ve been working on aarch64 GEMM-related stuff lately, so I’d like to provide some additional color.

  1. The MPS device does not appear to be relevant to this example.

  2. For CPU, the half @ half → half GEMM implementation in CPUBlas (which I believe is the relevant one) does not call into Accelerate, because (to the best of my knowledge) Accelerate does not support half or bfloat16 for GEMM. Instead, it calls into gemm_stub, which is implemented in ATen/native/cpu/BlasKernel.cpp. From there, depending on argument transposition, you will hit one of the naive GEMM kernels in that file. IIRC, the best one is gemm_transa_ (corresponding to the most common transposition case in LLMs, activations @ weights.t()), which still isn’t very good but at least uses our accelerated fp16 dot product kernel in ATen/native/cpu/ReducedPrecisionFloatGemvFastPathKernel.cpp. Note that matrix-vector multiplication of the form vector @ matrix.t() should hit the GEMV kernel in ReducedPrecisionFloatGemvFastPathKernel.cpp as a special case, because it is relevant to LLM decoding. (See the timing sketch after this list.)
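
To illustrate the shapes item 2 refers to, here is a rough timing sketch (the sizes are arbitrary; only the kernel and file names above come from the actual code):

```python
import time
import torch

def timeit(fn, iters=200):
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return time.perf_counter() - start

n = 512
x = torch.randn(n, n, dtype=torch.float16)  # "activations"
w = torch.randn(n, n, dtype=torch.float16)  # "weights"
v = torch.randn(1, n, dtype=torch.float16)  # single-row input, as in LLM decoding

# Plain half @ half GEMM: one of the naive kernels in BlasKernel.cpp.
print("x @ w    :", f"{timeit(lambda: x @ w):.4f}s")
# activations @ weights.t(): the transposition case handled by gemm_transa_.
print("x @ w.t():", f"{timeit(lambda: x @ w.t()):.4f}s")
# vector @ matrix.t(): should hit the GEMV fast path in
# ReducedPrecisionFloatGemvFastPathKernel.cpp.
print("v @ w.t():", f"{timeit(lambda: v @ w.t()):.4f}s")
```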

I am interested in getting better GEMMs for half and bfloat16 into PyTorch core, but:

  1. I am struggling with prioritization, given that for LLM inference it is unclear to me personally why 8-bit quantization wouldn’t be the obviously better performance/accuracy tradeoff instead.

  2. Our custom matrix multiplication infrastructure lives in torchao under torchao/experimental/kernels/cpu/aarch64 (it only contains low-bit quantized kernels, but that’s the infrastructure we’d likely extend to half/bfloat16), and therefore currently isn’t usable from PyTorch core.

  3. XNNPACK has fp16 kernels, but bf16 is planned-but-not-implemented, so if we wanted to simply use XNNPACK we would be incentivized to wait for their bf16 implementation.

The broader issue seems to be that most BLAS implementations don’t have 16-bit floating-point support.

P.S. I meant to provide a lot more links, but as a new user on this forum I am prohibited from doing so.
