I have two tensors, A and B, with the following dimensions:
- A.shape: [1024, 1, 1024, 1, 2048]
- B.shape: [1024, 1024, 1, 2048, 1]
I want to compute the matrix multiplication of these two tensors using A @ B in PyTorch. The output is correct, but there is a problem with memory consumption: to perform this computation, PyTorch first broadcasts the two tensors to a common shape. From A it creates an intermediate tensor of shape [1024, 1024, 1024, 1, 2048], repeating the size-1 second dimension 1024 times, and it does the same for B, producing [1024, 1024, 1024, 2048, 1]. Each of these intermediates would hold 1024 × 1024 × 1024 × 2048 elements, about 8 TiB each in float32, which is far too large. A minimal reproduction is sketched below.
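For concreteness, here is a minimal reproduction. The sizes are scaled down (n and d stand in for 1024 and 2048) so it actually runs, but the broadcasting behaviour is the same:

```python
import torch

# Scaled-down stand-ins for the real tensors; at full size n = 1024, d = 2048.
n, d = 8, 16
A = torch.randn(n, 1, n, 1, d)   # real shape: [1024, 1, 1024, 1, 2048]
B = torch.randn(n, n, 1, d, 1)   # real shape: [1024, 1024, 1, 2048, 1]

# matmul broadcasts the three leading (batch) dimensions of both operands
# to [n, n, n] before computing the [1, d] @ [d, 1] products.
C = A @ B
print(C.shape)  # torch.Size([8, 8, 8, 1, 1]); [1024, 1024, 1024, 1, 1] at full size
```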
Is there a more efficient way to perform the multiplication without materializing such broadcasts? Are there any other libraries that can handle this multiplication better than PyTorch? Below is a sketch of one reformulation I have been considering.
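For what it's worth, this is my own (unverified) sketch of the kind of rewrite I mean: since each [1, 2048] @ [2048, 1] product is just a dot product over the last axis, the whole computation should collapse into a single batched matmul once the size-1 dimensions are squeezed out:

```python
import torch

n, d = 8, 16  # stand-ins for 1024 and 2048, as above
A = torch.randn(n, 1, n, 1, d)
B = torch.randn(n, n, 1, d, 1)

# Drop the size-1 dimensions: A2 is indexed [i, k, d], B2 is indexed [i, j, d].
A2 = A.squeeze(3).squeeze(1)
B2 = B.squeeze(4).squeeze(2)

# Each output element is a dot product over d, so the product collapses
# into one batched matmul over the shared first dimension i.
C = torch.einsum('ikd,ijd->ijk', A2, B2)  # shape [n, n, n]
# Equivalently: C = torch.bmm(B2, A2.transpose(1, 2))

# Agrees with (A @ B).squeeze(-1).squeeze(-1) up to floating-point error.
print(torch.allclose(C, (A @ B).squeeze(-1).squeeze(-1), atol=1e-5))
```

However, I am not sure whether torch.einsum avoids the same broadcast internally, or whether this is the most memory-efficient way to express the contraction.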