I recently encountered the term GEMM (general matrix-matrix multiplication). I’m a bit confused about its usage in PyTorch: how does it differ from normal matrix-matrix multiplication? For example, I’ve read about turning a convolution into a matrix multiplication, i.e., an unfold + GEMM + reshape procedure. I’m wondering how GEMM is implemented in PyTorch.
Suppose I have a Conv layer. If I first unfold the input into a 3D tensor (with the first dimension being the batch dimension), reshape the weight into a 2D matrix, and then perform a multiplication
conv_mat = torch.matmul(weight_reshape, input_unfold)
and finally reshape conv_mat back to the shape of the convolution output, will this be regarded as using GEMM? If not, how do I get GEMM involved?
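To make the question concrete, here is a minimal sketch of the unfold + matmul + reshape procedure I have in mind, with hypothetical shapes chosen for illustration, checked against F.conv2d:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration: batch of 2 RGB images, 3x3 kernel.
N, C_in, H, W = 2, 3, 8, 8
C_out, kh, kw = 4, 3, 3

x = torch.randn(N, C_in, H, W)
weight = torch.randn(C_out, C_in, kh, kw)

# Unfold the input: (N, C_in*kh*kw, L), where L is the number of
# sliding-window positions.
x_unf = F.unfold(x, kernel_size=(kh, kw))   # (N, 27, 36)

# Reshape the weight to a 2D matrix: (C_out, C_in*kh*kw).
w_mat = weight.view(C_out, -1)              # (4, 27)

# The multiplication: torch.matmul broadcasts the 2D weight over the
# batch dimension, giving (N, C_out, L).
out_unf = torch.matmul(w_mat, x_unf)        # (2, 4, 36)

# Reshape back to the convolution output shape (no padding, stride 1).
H_out, W_out = H - kh + 1, W - kw + 1
out = out_unf.view(N, C_out, H_out, W_out)

# Sanity check against the built-in convolution.
ref = F.conv2d(x, weight)
print(torch.allclose(out, ref, atol=1e-5))  # True
```

Under the hood, torch.matmul on tensors like these dispatches to a (batched) GEMM routine from the linked BLAS/cuBLAS library, so this is my understanding of what “GEMM” refers to here.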