Hi all,
I recently encountered the term GEMM. I’m a bit confused about the usage of GEMM in PyTorch: how does it differ from ordinary matrix-matrix multiplication? For example, I’ve read about turning a convolution into a matrix multiplication, i.e., the unfold + GEMM + reshape procedure. I’m wondering how GEMM is implemented in PyTorch.
Suppose I have a Conv layer. If I first unfold the input into a 3D tensor (with the batch dimension first) and reshape the weight into a 2D matrix, then perform a multiplication
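For readers following along, the unfold + GEMM + reshape procedure described above can be sketched roughly as follows. This is only an illustration, not how PyTorch implements convolution internally. The shapes and variable names are assumptions chosen for the example, and it covers only the simplest case (stride 1, no padding, no bias).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative shapes (assumptions for this sketch, not taken from the thread).
batch, c_in, c_out, h, w, k = 2, 3, 4, 8, 8, 3

x = torch.randn(batch, c_in, h, w)
weight = torch.randn(c_out, c_in, k, k)

# 1) unfold: extract every k x k patch as a column.
#    Result is 3D: (batch, c_in * k * k, L), where L is the number of patches.
cols = F.unfold(x, kernel_size=k)

# 2) GEMM: flatten the weight to 2D and multiply.
#    The 2D weight is broadcast across the batch dimension of cols.
w2d = weight.view(c_out, -1)      # (c_out, c_in * k * k)
out = w2d @ cols                  # (batch, c_out, L)

# 3) reshape: turn L back into the spatial output size (stride 1, no padding).
h_out, w_out = h - k + 1, w - k + 1
out = out.view(batch, c_out, h_out, w_out)

# Compare against the built-in convolution.
reference = F.conv2d(x, weight)
matches = torch.allclose(out, reference, atol=1e-4)
print(matches)
```

With a stride, padding, or dilation, the same arguments would need to be passed to both `F.unfold` and `F.conv2d`, and the output size computed accordingly.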
GEMM (general matrix multiplication) is used as a synonym for “matrix multiplication” and comes from the “BLAS world” if I’m not mistaken.
Have you seen GEMM in the docs somewhere? If so, I think we should replace it.
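As a side note on the BLAS origin mentioned above: the BLAS GEMM routine computes the slightly more general form beta * C + alpha * (A @ B), and a plain matrix product is the special case alpha = 1, beta = 0. A minimal sketch using `torch.addmm`, which exposes the same form (the tensor names and sizes here are made up for illustration):

```python
import torch

torch.manual_seed(0)

# Hypothetical small operands, just for illustration.
A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.randn(3, 5)
alpha, beta = 2.0, 0.5

# GEMM-style update: beta * C + alpha * (A @ B)
gemm = torch.addmm(C, A, B, beta=beta, alpha=alpha)
manual = beta * C + alpha * (A @ B)
gemm_ok = torch.allclose(gemm, manual, atol=1e-5)

# With alpha = 1 and beta = 0 this reduces to an ordinary matrix product.
plain = torch.addmm(C, A, B, beta=0, alpha=1)
same_as_mm = torch.allclose(plain, torch.mm(A, B), atol=1e-5)

print(gemm_ok, same_as_mm)
```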
Thanks for your reply. No, I didn’t see GEMM in the PyTorch docs. But I saw that PyTorch has some matrix multiplication functions for BLAS operations, such as mm and bmm. They seem to handle matrix multiplication for inputs of different dimensions (2D/3D). I have two questions:
From your description, matmul is also a kind of GEMM, right? Since it is performing matrix multiplication.
How does matmul differ from the functions for BLAS operations (e.g., mm, bmm) in terms of performance? Do the BLAS functions have a faster implementation of matrix multiplication?
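To make the comparison in these two questions concrete, here is a small sketch showing that `torch.matmul` gives the same results as `torch.mm` on 2D inputs and `torch.bmm` on 3D inputs, and that it additionally broadcasts. It only checks numerical equivalence and says nothing about performance, which would have to be measured on the hardware in question. The shapes are arbitrary assumptions.

```python
import torch

torch.manual_seed(0)

# Arbitrary example shapes.
a2, b2 = torch.randn(3, 4), torch.randn(4, 5)          # 2D operands
a3, b3 = torch.randn(2, 3, 4), torch.randn(2, 4, 5)    # batched 3D operands

# 2D inputs: matmul agrees with mm.
same_2d = torch.allclose(torch.matmul(a2, b2), torch.mm(a2, b2), atol=1e-6)

# 3D inputs: matmul agrees with bmm.
same_3d = torch.allclose(torch.matmul(a3, b3), torch.bmm(a3, b3), atol=1e-6)

# matmul also broadcasts, which mm and bmm do not:
# a 2D left operand is applied to every matrix in the batch.
broadcast = torch.matmul(a2, b3)   # (3, 4) @ (2, 4, 5) -> (2, 3, 5)

print(same_2d, same_3d, tuple(broadcast.shape))
```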