Confusion about GEMM in PyTorch

Hi all,
I recently encountered the term GEMM, and I’m a bit confused about how it is used in PyTorch: how does it differ from normal matrix-matrix multiplication? For example, I’ve read about turning a convolution into a matrix multiplication, i.e., an unfold + GEMM + reshape procedure. I’m wondering how GEMM is implemented in PyTorch.

Suppose I have a Conv layer. I first unfold the input into a 3D tensor (with the first dimension being the batch dimension), reshape the weight into a 2D matrix, and then perform a multiplication

conv_mat = torch.matmul(input_unfold, weight_reshape)

Finally, I reshape conv_mat back to the shape of the convolution output. Will this count as using GEMM? If not, how can I get GEMM involved?
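The unfold + matmul + reshape procedure described above can be sketched roughly as follows. This is a minimal sketch assuming a 2D convolution with stride 1, no padding and no dilation; the sizes and the names `x`, `weight` and `bias` are hypothetical. Note that `F.unfold` returns a tensor of shape (N, C_in*kH*kW, L), so it has to be transposed before it can be multiplied by the reshaped weight in the order used in `torch.matmul(input_unfold, weight_reshape)`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: batch 2, 3 input channels, 8x8 images,
# 4 output channels, 3x3 kernel (stride 1, no padding assumed).
N, C_in, H, W = 2, 3, 8, 8
C_out, kH, kW = 4, 3, 3

x = torch.randn(N, C_in, H, W)
weight = torch.randn(C_out, C_in, kH, kW)
bias = torch.randn(C_out)

# 1) Unfold: (N, C_in*kH*kW, L), where L is the number of sliding-window positions.
input_unfold = F.unfold(x, kernel_size=(kH, kW))

# 2) Reshape the weight into a 2D matrix: (C_in*kH*kW, C_out).
weight_reshape = weight.view(C_out, -1).t()

# 3) Matrix multiplication, broadcast over the batch dimension:
#    (N, L, C_in*kH*kW) @ (C_in*kH*kW, C_out) -> (N, L, C_out)
conv_mat = torch.matmul(input_unfold.transpose(1, 2), weight_reshape) + bias

# 4) Reshape back to the convolution output shape (N, C_out, H_out, W_out).
H_out, W_out = H - kH + 1, W - kW + 1
out = conv_mat.transpose(1, 2).reshape(N, C_out, H_out, W_out)

# Compare against the built-in convolution.
ref = F.conv2d(x, weight, bias)
print(torch.allclose(out, ref, atol=1e-4))
```

As far as I know, on typical CPU and CUDA builds the `torch.matmul` call in step 3 ends up in a BLAS-style GEMM kernel (e.g. from MKL, OpenBLAS or cuBLAS), so in that sense this procedure is already "using GEMM".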

GEMM (general matrix multiplication) is used as a synonym for “matrix multiplication” and comes from the “BLAS world” if I’m not mistaken.
Have you seen GEMM in the docs somewhere? If so, I think we should replace it.
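For reference, the BLAS GEMM routines (e.g. `sgemm`/`dgemm`) compute the slightly more general update C ← α·op(A)·op(B) + β·C, where op optionally transposes its argument; plain matrix multiplication is the special case α = 1, β = 0. A small sketch of that form in PyTorch uses `torch.addmm`, which computes `beta * input + alpha * (mat1 @ mat2)`. This only illustrates the GEMM formula; it is not a claim about how PyTorch calls BLAS internally, and the matrix sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 3)
B = torch.randn(3, 5)
C = torch.randn(4, 5)
alpha, beta = 2.0, 0.5

# GEMM-style update in one call: beta * C + alpha * (A @ B)
gemm_result = torch.addmm(C, A, B, beta=beta, alpha=alpha)

# The same computation spelled out with plain matrix multiplication.
manual = alpha * (A @ B) + beta * C

# With alpha=1 and beta=0 this reduces to ordinary matrix multiplication.
plain = torch.addmm(C, A, B, beta=0.0, alpha=1.0)

print(torch.allclose(gemm_result, manual, atol=1e-6))
```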

Thanks for your reply. No, I didn’t see GEMM in the PyTorch docs. But I did see some matrix multiplication functions in PyTorch for BLAS operations, such as mm and bmm. They seem to handle matrix multiplications of different dimensions (2D/3D). I have two questions:

  1. From your description, matmul is also a kind of GEMM, right, since it performs matrix multiplication?
  2. How does matmul compare in performance with the functions for BLAS operations (e.g., mm, bmm)? Do the BLAS functions have a faster implementation of matrix multiplication?

  1. Yes, these are just different names from different domains, perhaps similar to how image processing speaks of “filters” where convolutional networks speak of “kernels”.

  2. matmul calls into different implementations depending on the dimensions of its inputs, as seen here.
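A rough way to check this yourself: `torch.matmul` looks at the number of dimensions of its inputs and then behaves like a dot product, `torch.mm`, `torch.bmm`, or a broadcast batched product. The sketch below checks those equivalences and times `matmul` against `bmm` for one arbitrary shape using `torch.utils.benchmark`. The sizes are assumptions, and the timings depend on hardware and build, so no numbers are claimed here; since matmul dispatches to the same underlying routines, I would not expect a large difference for matching shapes, but measure on your own setup.

```python
import torch
import torch.utils.benchmark as benchmark

torch.manual_seed(0)

# 1D @ 1D -> dot product (a 0-dim tensor)
v1, v2 = torch.randn(3), torch.randn(3)
same_as_dot = torch.allclose(torch.matmul(v1, v2), torch.dot(v1, v2))

# 2D @ 2D -> same result as torch.mm
a, b = torch.randn(64, 32), torch.randn(32, 16)
same_as_mm = torch.allclose(torch.matmul(a, b), torch.mm(a, b))

# 3D @ 3D -> same result as torch.bmm (batched matrix multiplication)
x, y = torch.randn(8, 64, 32), torch.randn(8, 32, 16)
same_as_bmm = torch.allclose(torch.matmul(x, y), torch.bmm(x, y))

# 3D @ 2D -> the 2D operand is broadcast over the batch dimension,
# which torch.bmm does not do on its own (it needs an explicit expand).
broadcast = torch.matmul(x, b)  # shape (8, 64, 16)
same_as_expanded_bmm = torch.allclose(broadcast, torch.bmm(x, b.expand(8, 32, 16)))


def median_time(stmt):
    """Median runtime in seconds of `stmt`, measured with torch.utils.benchmark."""
    timer = benchmark.Timer(stmt=stmt, globals={"torch": torch, "x": x, "y": y})
    return timer.blocked_autorange().median


print("matmul:", median_time("torch.matmul(x, y)"))
print("bmm:   ", median_time("torch.bmm(x, y)"))
```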


Thanks for your help!