Unnecessary memory usage when calculating the dot product of indexed tensors

Consider the following code:

import torch


@torch.jit.script
def foo(x, xi, y, yi):
    # row-wise dot product between the gathered rows x[xi] and y[yi]
    return torch.sum(x[xi] * y[yi], dim=1)


emb_size = 100
cooc_count = 10_000_000
word_count = 12543

device = "cuda"

xw = torch.randn(word_count, emb_size, device=device)
yw = torch.randn(word_count, emb_size, device=device)

xi = torch.randint(word_count, (cooc_count,), device=device)
yi = torch.randint(word_count, (cooc_count,), device=device)

dots = foo(xw, xi, yw, yi)

For those curious, this comes up when training word embeddings, but the background is not really relevant; I'm interested in the question in general.

When I run this on my machine the code crashes with a "CUDA out of memory" error, which makes sense: x[xi] and y[yi] expand to very large tensors of shape (cooc_count, emb_size) with many repeated rows. The allocation is unnecessary, though, because those tensors are immediately collapsed into a much smaller tensor of shape (cooc_count,) by the multiply-and-sum that implements the row-wise dot product.
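To put a rough number on it (back-of-the-envelope, assuming float32 at 4 bytes per element):

cooc_count, emb_size, bytes_per_float32 = 10_000_000, 100, 4
print(cooc_count * emb_size * bytes_per_float32 / 1e9)  # ≈ 4.0 GB for x[xi] alone
# y[yi] and the elementwise product each need roughly the same again,
# before the sum collapses everything down to a (cooc_count,) result

So the intermediates alone are on the order of 10 GB, while the final result is only about 40 MB.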

I was hoping the JIT would recognize this pattern and optimize foo to compute the dot product directly while indexing. Is this a known limitation? Is there a way to rewrite the expression so that the intermediate allocation is avoided?
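For reference, the best fallback I can think of is manual chunking over the co-occurrence dimension, which only bounds the intermediate rather than eliminating it (foo_chunked and chunk_size are names I'm using just for illustration, and the chunk size is arbitrary):

def foo_chunked(x, xi, y, yi, chunk_size: int = 1_000_000):
    # bounds the intermediate to (chunk_size, emb_size) instead of (cooc_count, emb_size)
    out = torch.empty(xi.shape[0], device=x.device, dtype=x.dtype)
    for start in range(0, xi.shape[0], chunk_size):
        end = start + chunk_size
        out[start:end] = torch.sum(x[xi[start:end]] * y[yi[start:end]], dim=1)
    return out

I'd rather not hand-roll a loop like this if the JIT (or some indexing primitive) can fuse the gather with the reduction.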