Efficient multiplication over the last dimension

I have two torch tensors: A of shape (15, 100, 256) and B of shape (120, 2010, 256). How can I efficiently multiply them over the last dimension and get a tensor of shape (15, 100, 120, 2010)? I tried something like torch.einsum('ijk, mnk -> ijmn', A, B), but as I understand it, this method implicitly creates intermediate tensors and takes a lot of memory and time. I also tried the opt_einsum library, but didn't see a big difference in time or memory usage compared to torch.einsum.
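For reference, a minimal version of what I'm doing now:

import torch

A = torch.randn(15, 100, 256)
B = torch.randn(120, 2010, 256)

# contract over the shared last dimension of size 256
out = torch.einsum('ijk, mnk -> ijmn', A, B)  # shape: (15, 100, 120, 2010)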

Of course, I could just do it in a loop, but I want a solution that is efficient in both time and memory usage. Thank you in advance.

The problem is that because you're broadcasting and the operation then routes to a batched matmul, PyTorch will materialize the fully broadcasted and expanded tensor in memory first and only then run the batched matmul, which is bad for memory usage, as you've seen.
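As a rough back-of-the-envelope check, assuming float32 and that the broadcasted operand really does get materialized, the expanded copy alone would already be far larger than any GPU memory:

i, j, k = 15, 100, 256
m, n = 120, 2010

# the broadcasted copy of B that the batched matmul would see has shape (i, j, m, k, n)
elements = i * j * m * k * n
print(f"{elements * 4 / 2**30:.0f} GiB in float32")  # roughly 345 GiB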

The first thing I would try is to write it out without torch.einsum, i.e. as unsqueeze + matmul, wrap it with torch.compile, and see whether PyTorch is able to compile it into something efficient.
Something like this:

import torch

# Sample dimensions
i, j, k, m, n = 2, 3, 4, 5, 6
A = torch.randn(i, j, k)
B = torch.randn(m, n, k)

@torch.compile
def outermul(A, B):
    # Reshape and expand A and B for batched multiplication
    A_expanded = A.unsqueeze(2).unsqueeze(3)  # Shape: (i, j, 1, 1, k)
    B_expanded = B.unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, m, n, k)
    B_expanded = B_expanded.transpose(3, 4)  # Now Shape: (1, 1, m, k, n)

    # Multiply using torch.matmul, automatically summing over the last dimension of A and the second-to-last of B
    result = torch.matmul(A_expanded, B_expanded).squeeze(3)  # Resulting shape: (i, j, m, n)
    return result

result = outermul(A, B)
print(result.shape)  # torch.Size([2, 3, 5, 6]) with the sample dims above

If that doesn't quite work out to your liking and you're inclined to try something a bit more advanced, you could write a Triton kernel for this; it should be efficient while also being fairly concise. Here are a couple of links for you to get started with Triton:

@smth The problem with unsqueeze + matmul is its large memory usage. As I understand it, unsqueeze adds a dimension, and matmul then creates a broadcasted copy of the tensor under the hood, and that copy uses a lot of memory.
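For example, the unsqueeze/expand steps themselves are only views; as far as I can tell, the copy only appears once the broadcasted operand is forced into contiguous memory (which I believe matmul does internally):

import torch

A = torch.randn(15, 100, 256)
B = torch.randn(120, 2010, 256)

B_view = B.unsqueeze(0).unsqueeze(0).transpose(3, 4)   # (1, 1, 120, 256, 2010), still a view
B_broadcast = B_view.expand(15, 100, 120, 256, 2010)   # broadcast shape, still a view
print(B_broadcast.data_ptr() == B.data_ptr())          # True: no memory has been copied yet
# B_broadcast.contiguous() would allocate the full expanded tensor, which does not fit in memory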

That's why I suggested wrapping it in torch.compile: the compiler might generate a memory-efficient fused kernel for it if we're lucky.

Otherwise, writing a custom Triton kernel is the only way to go.
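As a very rough, untested sketch of what that could look like (the kernel name, block sizes, and wrapper below are placeholder choices, not a tuned implementation): the idea is to flatten the leading dimensions and compute out[p, q] = sum over k of A2[p, k] * B2[q, k] tile by tile, so the full broadcasted operand is never materialized.

import torch
import triton
import triton.language as tl

@triton.jit
def pairwise_dot_kernel(a_ptr, b_ptr, out_ptr, P, Q, K,
                        stride_ap, stride_ak,
                        stride_bq, stride_bk,
                        stride_op, stride_oq,
                        BLOCK_P: tl.constexpr, BLOCK_Q: tl.constexpr, BLOCK_K: tl.constexpr):
    # each program instance computes one (BLOCK_P, BLOCK_Q) tile of out = A2 @ B2.T
    pid_p = tl.program_id(0)
    pid_q = tl.program_id(1)
    offs_p = pid_p * BLOCK_P + tl.arange(0, BLOCK_P)
    offs_q = pid_q * BLOCK_Q + tl.arange(0, BLOCK_Q)
    offs_k = tl.arange(0, BLOCK_K)

    acc = tl.zeros((BLOCK_P, BLOCK_Q), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        a = tl.load(a_ptr + offs_p[:, None] * stride_ap + (k0 + offs_k)[None, :] * stride_ak,
                    mask=(offs_p[:, None] < P) & ((k0 + offs_k)[None, :] < K), other=0.0)
        b = tl.load(b_ptr + offs_q[:, None] * stride_bq + (k0 + offs_k)[None, :] * stride_bk,
                    mask=(offs_q[:, None] < Q) & ((k0 + offs_k)[None, :] < K), other=0.0)
        acc += tl.dot(a, tl.trans(b))  # accumulate partial dot products over the K dimension

    out_ptrs = out_ptr + offs_p[:, None] * stride_op + offs_q[None, :] * stride_oq
    tl.store(out_ptrs, acc, mask=(offs_p[:, None] < P) & (offs_q[None, :] < Q))


def outermul_triton(A, B, BLOCK_P=64, BLOCK_Q=64, BLOCK_K=32):
    # A: (i, j, k), B: (m, n, k), both CUDA tensors; returns (i, j, m, n)
    i, j, k = A.shape
    m, n, _ = B.shape
    A2 = A.reshape(i * j, k).contiguous()
    B2 = B.reshape(m * n, k).contiguous()
    out = torch.empty(i * j, m * n, device=A.device, dtype=torch.float32)
    grid = (triton.cdiv(i * j, BLOCK_P), triton.cdiv(m * n, BLOCK_Q))
    pairwise_dot_kernel[grid](A2, B2, out, i * j, m * n, k,
                              A2.stride(0), A2.stride(1),
                              B2.stride(0), B2.stride(1),
                              out.stride(0), out.stride(1),
                              BLOCK_P=BLOCK_P, BLOCK_Q=BLOCK_Q, BLOCK_K=BLOCK_K)
    return out.reshape(i, j, m, n)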