If I understand correctly, decomposing a torch.matmul of big matrices into many torch.matmul calls on smaller matrices makes processing on the GPU slower, because many small operations are less efficient than one large one. For example, consider a matmul between two 30000x30000 matrices. Now consider splitting each matrix into 300 x 300 = 90,000 sub-matrices of size 100x100 and computing the same product block by block, which takes 300³ = 27,000,000 matmuls between 100x100 sub-matrices. The former is clearly faster, even though both perform the same number of floating-point operations (FLOPs) and consume the same amount of memory. How can we bring the speed of the latter closer to that of the former?
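To make the counting concrete, here is a minimal sketch that performs the block decomposition explicitly and counts the small matmuls it issues. The helper name blocked_matmul is hypothetical, and the sizes are kept small (60x60 matrices, 10x10 blocks) so it runs quickly on CPU. The last lines redo the bookkeeping for the sizes in the question without computing anything.

```python
import torch

def blocked_matmul(a: torch.Tensor, b: torch.Tensor, bs: int):
    """Compute a @ b for square (n x n) inputs by looping over bs x bs blocks.

    Returns the result and the number of small matmuls issued.
    Assumes bs divides n.
    """
    n = a.shape[0]
    g = n // bs  # number of blocks along each side

    def blk(t, r, s):
        # View of block (r, s) of t
        return t[r * bs:(r + 1) * bs, s * bs:(s + 1) * bs]

    c = torch.zeros(n, n, dtype=a.dtype, device=a.device)
    calls = 0
    for i in range(g):
        for j in range(g):
            for k in range(g):
                # C_ij += A_ik @ B_kj: one small matmul per (i, j, k) triple
                blk(c, i, j).add_(blk(a, i, k) @ blk(b, k, j))
                calls += 1
    return c, calls

torch.manual_seed(0)
a = torch.randn(60, 60, dtype=torch.float64)
b = torch.randn(60, 60, dtype=torch.float64)
c, calls = blocked_matmul(a, b, bs=10)
print(torch.allclose(c, a @ b), calls)  # True 216  (6 blocks per side, 6**3 products)

# Same bookkeeping for the sizes in the question (no computation performed)
n, bs = 30000, 100
print((n // bs) ** 2, (n // bs) ** 3)  # 90000 27000000
print((n // bs) ** 3 * 2 * bs**3 == 2 * n**3)  # True: same FLOP count
```

Each block product costs 2·bs³ FLOPs, so g³ products of that size add up to exactly the 2·n³ FLOPs of the single large matmul. The difference lies in issuing g³ separate small calls instead of one.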
Often you can exploit known structure in the matrices by decomposing each one into many sub-matrices of varying widths and heights. This gives numerous matrix multiplications, most of which involve a zero sub-matrix. Since you know beforehand which sub-matrices are zero, those products can be filtered out, but thousands of matmuls still remain. Given a prescription of which pairs to multiply, how can these sub-matrices be multiplied efficiently?
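One commonly used way to cut the number of calls, sketched here under stated assumptions rather than offered as the fastest method, is to group the prescribed products by their (m, k, n) shape, issue one torch.bmm per group, and scatter-add the results into the output blocks with Tensor.index_add_. The sketch assumes blocks are kept as Python lists of 2-D tensors and that the prescription is a list of hypothetical (ia, ib, ic) index triples meaning C[ic] += A[ia] @ B[ib]; the names block_sparse_matmul and split are also hypothetical.

```python
from collections import defaultdict

import torch

def block_sparse_matmul(a_blocks, b_blocks, out_shapes, triples):
    """Accumulate C[ic] += A[ia] @ B[ib] for every (ia, ib, ic) in triples.

    Products with the same (m, k, n) shape are stacked and issued as one
    torch.bmm call; the results are scatter-added into the output blocks
    with index_add_, which sums contributions that share an index.
    """
    dtype, device = a_blocks[0].dtype, a_blocks[0].device

    # Give every output block a slot in a stacked tensor of its shape.
    slot, counts = {}, defaultdict(int)
    for ic, shape in enumerate(out_shapes):
        shape = tuple(shape)
        slot[ic] = (shape, counts[shape])
        counts[shape] += 1
    c_stacks = {s: torch.zeros(cnt, *s, dtype=dtype, device=device)
                for s, cnt in counts.items()}

    # Group the prescribed products by operand shapes.
    groups = defaultdict(list)
    for ia, ib, ic in triples:
        m, k = a_blocks[ia].shape
        n = b_blocks[ib].shape[1]
        groups[(m, k, n)].append((ia, ib, ic))

    for (m, k, n), items in groups.items():
        lhs = torch.stack([a_blocks[ia] for ia, _, _ in items])  # (G, m, k)
        rhs = torch.stack([b_blocks[ib] for _, ib, _ in items])  # (G, k, n)
        prods = torch.bmm(lhs, rhs)  # (G, m, n) in a single call
        idx = torch.tensor([slot[ic][1] for _, _, ic in items], device=device)
        c_stacks[(m, n)].index_add_(0, idx, prods)

    # Return the output blocks in the same order as out_shapes.
    return [c_stacks[slot[ic][0]][slot[ic][1]] for ic in range(len(out_shapes))]

# Demo: rows/cols of varying size; block A[0][1] is known to be zero.
torch.manual_seed(0)
rows, inner, cols = [2, 2, 3], [3, 3], [2, 2]
A = torch.randn(sum(rows), sum(inner), dtype=torch.float64)
A[:2, 3:] = 0
B = torch.randn(sum(inner), sum(cols), dtype=torch.float64)

def split(m, rs, cs):
    # Row-major flat list of blocks with the given row/column sizes
    return [blk for band in m.split(rs, dim=0) for blk in band.split(cs, dim=1)]

a_blocks, b_blocks = split(A, rows, inner), split(B, inner, cols)
out_shapes = [(r, c) for r in rows for c in cols]
triples = [(I * len(inner) + K, K * len(cols) + J, I * len(cols) + J)
           for I in range(len(rows)) for J in range(len(cols)) for K in range(len(inner))
           if (I, K) != (0, 1)]  # skip products involving the zero block
c_blocks = block_sparse_matmul(a_blocks, b_blocks, out_shapes, triples)
c_full = torch.cat([torch.cat(c_blocks[I * len(cols):(I + 1) * len(cols)], dim=1)
                    for I in range(len(rows))], dim=0)
print(len(triples), torch.allclose(c_full, A @ B))  # 10 True
```

In the demo, the 10 prescribed products fall into only two distinct (m, k, n) shapes, so the sketch issues 2 bmm calls instead of 10. How much this helps in practice depends on how many products share a shape. Also, torch.stack copies the operand blocks on every call, so if the same prescription is reused, packing the operands once up front would avoid that cost.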