I recently encountered the term GEMM. I’m a bit confused about its usage in PyTorch: how does it differ from normal matrix-matrix multiplication? For example, I’ve read about turning a convolution into a matrix multiplication, i.e., the unfold + GEMM + reshape procedure. I’m wondering how GEMM is implemented in PyTorch.
Suppose I have a Conv layer. If I first unfold the input into a 3D tensor (with the first dimension being the batch dimension), reshape the weight into a 2D matrix, and then perform a multiplication
conv_mat = torch.matmul(input_unfold, weight_reshape)
and finally reshape conv_mat back to the shape of the convolution output, will this be regarded as using GEMM? If not, how do I get GEMM involved?
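For concreteness, here is a minimal sketch of the pipeline I have in mind, roughly following the example in the nn.Unfold docs (the sizes are arbitrary):

```python
import torch
import torch.nn.functional as F

# arbitrary example sizes: batch 2, 3 input channels, 4 output channels, 3x3 kernel
inp = torch.randn(2, 3, 10, 12)
w = torch.randn(4, 3, 3, 3)

# unfold: (N, C*k*k, L), where L is the number of sliding-window positions
inp_unf = F.unfold(inp, kernel_size=(3, 3))

# reshape the weight to 2D and multiply:
# (N, L, C*k*k) @ (C*k*k, out_ch) -> (N, L, out_ch), then move channels back
out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)

# reshape back to the convolution output shape (N, out_ch, H_out, W_out)
out = out_unf.view(2, 4, 8, 10)

# compare against the built-in convolution
print((out - F.conv2d(inp, w)).abs().max())
```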
GEMM (general matrix multiplication) is used as a synonym for “matrix multiplication” and comes from the “BLAS world”, if I’m not mistaken.
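To make that concrete: in BLAS, GEMM computes `C = alpha * A @ B + beta * C`, and torch.addmm exposes essentially that signature. A minimal sketch with arbitrary sizes:

```python
import torch

A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.randn(3, 5)

# BLAS GEMM computes C <- alpha * A @ B + beta * C;
# torch.addmm mirrors that signature directly
out = torch.addmm(C, A, B, beta=0.5, alpha=2.0)

# equivalent to the explicit expression
print(torch.allclose(out, 0.5 * C + 2.0 * A @ B))
```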
Have you seen GEMM in the docs somewhere? If so, I think we should replace it.
Thanks for your reply. No, I didn’t see GEMM in the PyTorch docs, but I saw there are some matrix multiplication functions in PyTorch listed under BLAS operations, such as bmm. It seems they handle matrix multiplications of different dimensions (2D/3D). I have two questions:
- From your description, matmul is also a kind of GEMM, right, since it performs matrix multiplication?
- How does matmul differ from the functions listed under BLAS operations (e.g., bmm) in terms of performance (see the naive timing sketch below)? Do the BLAS functions have a faster implementation of matrix multiplication?
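For instance, here is a naive timing sketch I could run to compare them on 3D inputs (arbitrary sizes, CPU only, since CUDA timing would additionally need synchronization):

```python
import time
import torch

a = torch.randn(64, 128, 256)
b = torch.randn(64, 256, 128)

def bench(fn, iters=100):
    # warm-up, then average wall-clock time per call
    for _ in range(5):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

print("matmul:", bench(lambda: torch.matmul(a, b)))
print("bmm:   ", bench(lambda: torch.bmm(a, b)))
```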