The problem is that because you're broadcasting and the operation then routes to a batch matmul, PyTorch materializes the fully broadcasted and expanded tensor in memory first and only then runs the batch matmul, which is bad for memory usage, as you're seeing.
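For concreteness, here is a minimal sketch of the pattern I'm assuming (the einsum string and sizes are my reading of your setup; adjust them to your actual case):

import torch

# Assumed contraction: out[i, j, m, n] = sum_k A[i, j, k] * B[m, n, k]
i, j, k, m, n = 2, 3, 4, 5, 6
A = torch.randn(i, j, k)
B = torch.randn(m, n, k)

# This kind of einsum can end up materializing the broadcasted
# (i, j, m, n, k) intermediate before the batch matmul runs.
out = torch.einsum('ijk,mnk->ijmn', A, B)  # shape: (i, j, m, n)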
The first thing I would try is to write it out without torch.einsum, i.e. just write it as unsqueeze + matmul, then wrap it with torch.compile and see if PyTorch is able to compile it.
Something like this:
import torch

# Sample dimensions
i, j, k, m, n = 2, 3, 4, 5, 6
A = torch.randn(i, j, k)
B = torch.randn(m, n, k)

@torch.compile
def outermul(A, B):
    # Reshape and expand A and B for batched multiplication
    A_expanded = A.unsqueeze(2).unsqueeze(3)  # shape: (i, j, 1, 1, k)
    B_expanded = B.unsqueeze(0).unsqueeze(0)  # shape: (1, 1, m, n, k)
    B_expanded = B_expanded.transpose(3, 4)   # shape: (1, 1, m, k, n)
    # torch.matmul contracts the last dim of A_expanded with the second-to-last
    # dim of B_expanded, giving (i, j, m, 1, n); squeeze(3) drops the size-1 dim.
    result = torch.matmul(A_expanded, B_expanded).squeeze(3)  # shape: (i, j, m, n)
    return result

result = outermul(A, B)
print(result.shape)  # torch.Size([2, 3, 5, 6])
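As a quick sanity check (again assuming 'ijk,mnk->ijmn' is the contraction you intend), you can compare the two versions on small inputs:

# Verify the unsqueeze + matmul version against the einsum on small tensors.
expected = torch.einsum('ijk,mnk->ijmn', A, B)
assert torch.allclose(result, expected, atol=1e-6)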
If that doesn’t quite work out to your liking and you are inclined to try something a bit more advanced, you should write a Triton kernel for this; it should be efficient while also being concise. Here are a couple of links to get you started with Triton: