Matrix multiplication source code?

I’m interested in finding out some specific implementation details of matrix multiplication in PyTorch. Where would one find the source code (CPU implementation and CUDA kernel) for PyTorch’s implementation of matrix multiplication? Specifically, where would one find the code implementing torch.bmm()?

I see that this question was already asked here, but not answered.

Thanks!

Here is a blog post on how to get from a Python PyTorch function to ATen.
For CPU, this will get you to bmm_cpu in LinearAlgebra.cpp. This in turn calls bmm_out_or_baddbmm_ in the same file. For very small operands it calls a (somewhat lame) kernel of its own; for larger floating-point operands it dispatches to MKL if available, and otherwise falls back to a series of mv (matrix-vector) calls.
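
To make the mv fallback concrete, here is a minimal pure-Python sketch of what torch.bmm computes: a (b, n, m) batch times a (b, m, p) batch gives a (b, n, p) batch, and each output column can be obtained as one matrix-vector product. The names `mv` and `bmm_via_mv` are hypothetical, and this is not the code in LinearAlgebra.cpp; it only illustrates the decomposition into mv calls.

```python
# Hypothetical illustration, NOT PyTorch's implementation: batched matrix
# multiply expressed as a loop of matrix-vector (mv) products.
from typing import List

Matrix = List[List[float]]


def mv(a: Matrix, x: List[float]) -> List[float]:
    """Matrix-vector product: a is n x m, x has length m."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]


def bmm_via_mv(batch1: List[Matrix], batch2: List[Matrix]) -> List[Matrix]:
    """(b, n, m) x (b, m, p) -> (b, n, p), computed one output column at a time."""
    if len(batch1) != len(batch2):
        raise ValueError("batch sizes must match")
    out: List[Matrix] = []
    for a, b in zip(batch1, batch2):
        m = len(b)
        if any(len(row) != m for row in a):
            raise ValueError("inner dimensions must match")
        p = len(b[0]) if b else 0
        # Column j of (a @ b) equals a @ (column j of b): one mv per column.
        cols = [mv(a, [b[k][j] for k in range(m)]) for j in range(p)]
        # Transpose the list of result columns back into rows.
        out.append([[cols[j][i] for j in range(p)] for i in range(len(a))])
    return out


print(bmm_via_mv([[[1, 2], [3, 4]]], [[[5, 6], [7, 8]]]))  # → [[[19, 22], [43, 50]]]
```

The Python loops only show which products get computed, not how fast a real backend computes them.
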
For CUDA, you end up in LinearAlgebra.cu, which (at the time of writing) forwards to the legacy THC implementation of baddbmm. For technical reasons (the old Torch7 “generic” macro-based tensor dtype support), that in turn uses functions from THCBlas.cu to finally dispatch to cuBLAS or rocBLAS.
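
On the CUDA side the work ends in baddbmm, which according to PyTorch's documentation computes out = beta * input + alpha * (batch1 @ batch2) for every batch entry (bmm corresponds to dropping the input term). The pure-Python sketch below shows those semantics, together with the idea of keeping a whole batch in one flat buffer and finding entry i at offset i times a fixed stride, which is roughly how strided-batched GEMM routines in BLAS libraries such as cuBLAS receive their operands. The function name `baddbmm_strided` and its signature are hypothetical; real BLAS calls also take leading dimensions, and cuBLAS assumes column-major storage, which this row-major sketch ignores.

```python
# Hypothetical sketch, NOT the THC/cuBLAS code: baddbmm semantics over flat,
# row-major buffers, with each batch entry located at a fixed stride.
from typing import List


def baddbmm_strided(inp: List[float], batch1: List[float], batch2: List[float],
                    b: int, n: int, m: int, p: int,
                    beta: float = 1.0, alpha: float = 1.0) -> List[float]:
    """out[i] = beta * inp[i] + alpha * (batch1[i] @ batch2[i]) for i in range(b).

    batch1 is b x n x m, batch2 is b x m x p, inp and out are b x n x p.
    """
    stride1, stride2, stride_out = n * m, m * p, n * p
    out = [0.0] * (b * stride_out)
    for i in range(b):  # batch entries never interact with each other
        a0, b0, c0 = i * stride1, i * stride2, i * stride_out
        for r in range(n):
            for c in range(p):
                acc = 0.0
                for k in range(m):
                    acc += batch1[a0 + r * m + k] * batch2[b0 + k * p + c]
                idx = c0 + r * p + c
                # As documented for torch.baddbmm, beta == 0 ignores inp,
                # so NaN/inf values in inp do not propagate.
                out[idx] = alpha * acc if beta == 0 else beta * inp[idx] + alpha * acc
    return out


# Two 2x2 batch entries stored back to back; beta=0 so inp is ignored.
print(baddbmm_strided([0.0] * 8, [1, 2, 3, 4, 1, 0, 0, 1], [5, 6, 7, 8, 2, 3, 4, 5],
                      b=2, n=2, m=2, p=2, beta=0.0))
# → [19.0, 22.0, 43.0, 50.0, 2.0, 3.0, 4.0, 5.0]
```

Since the batch entries are independent, the outer loop over i is the natural place for a GPU library to run work concurrently.
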

Phew. Now I know why no one answered the question you linked.

Best regards

Thomas

Wow, very thorough, @tom! I really appreciate you taking the time. It’s quite the labyrinth, isn’t it?