Consider the following code:
```python
import torch

@torch.jit.script
def foo(x, xi, y, yi):
    # Row-wise dot product between the selected rows of x and y.
    return torch.sum(x[xi] * y[yi], dim=1)

emb_size = 100
cooc_count = 10_000_000
word_count = 12543
device = "cuda"

# Embedding tables and the index pairs to take dot products over.
xw = torch.randn(word_count, emb_size, device=device)
yw = torch.randn(word_count, emb_size, device=device)
xi = torch.randint(word_count, (cooc_count,), device=device)
yi = torch.randint(word_count, (cooc_count,), device=device)

dots = foo(xw, xi, yw, yi)
```
For those curious, this comes up when training word embeddings, but that background is not really relevant; I’m more interested in this question in general.
When I run this on my computer, the code crashes with a “CUDA out of memory” error, which makes sense: `x[xi]` and `y[yi]` expand to very large tensors of shape `(cooc_count, emb_size)` with many repeated rows (at float32, each one is 10_000_000 × 100 × 4 bytes ≈ 4 GB). This allocation is unnecessary, though, because those tensors are immediately collapsed into a much smaller tensor of shape `(cooc_count,)` by the multiplication and sum, which together implement a row-wise dot product.
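One manual workaround I’ve tried is to process the indices in chunks so the gathered intermediates stay bounded. This is just a sketch; `foo_chunked` and the particular `chunk_size` are my own choices, not anything PyTorch provides:

```python
def foo_chunked(x, xi, y, yi, chunk_size: int = 1_000_000):
    # Same result as foo, but the gathered intermediates are at most
    # (chunk_size, emb_size) instead of (cooc_count, emb_size).
    out = torch.empty(xi.shape[0], device=x.device, dtype=x.dtype)
    for start in range(0, xi.shape[0], chunk_size):
        end = min(start + chunk_size, xi.shape[0])
        out[start:end] = torch.sum(x[xi[start:end]] * y[yi[start:end]], dim=1)
    return out
```

This keeps memory under control, but it adds a Python-level loop and a tuning knob.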
I was hoping the JIT would recognize this pattern and optimize `foo` to compute the dot product directly while indexing, without materializing the intermediates. Is this a known limitation? Is there a way to rewrite the expression so that it avoids the intermediate allocation?
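For completeness, another rewrite I’ve considered exploits the fact that `word_count` is much smaller than `cooc_count` here: precompute all pairwise dot products once and index into the result. This assumes the `(word_count, word_count)` matrix fits in memory, which it does at these sizes:

```python
# All pairwise dot products: (word_count, word_count), about 630 MB at
# float32 for word_count = 12543, versus ~4 GB per gathered intermediate.
all_dots = xw @ yw.t()
dots = all_dots[xi, yi]  # fancy indexing picks out the needed pairs
```

Neither of these feels as clean as the original expression, though, so I’d still like to know whether the JIT can fuse it.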