Incomplete/sparse Gemm (for some sort of local attention) with only given indices

I would like to do an incomplete GEMM, where the dot products are computed only against a fixed number of other matrix rows (the full product matrix would not fit in memory).

Given two tensors: emb[B, T, C] and ind[B, T, K] where ind[b, t, :] contains indices of the neighborhood of (b, t).

I would like to compute out[b, t, k] = \sum_c emb[b, t, c] * emb[b, ind[b, t, k], c].
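For reference, here is a minimal sketch of the computation above using advanced indexing and einsum. The function name neighborhood_dots is made up for illustration, and ind is assumed to be an int64 tensor of indices along the T axis. Note that this baseline does materialize the gathered [B, T, K, C] tensor, which is exactly the intermediate I would like to avoid.

```python
import torch


def neighborhood_dots(emb: torch.Tensor, ind: torch.Tensor) -> torch.Tensor:
    """Baseline: out[b, t, k] = sum_c emb[b, t, c] * emb[b, ind[b, t, k], c].

    emb: [B, T, C] float tensor
    ind: [B, T, K] int64 tensor with values in [0, T)
    """
    B = emb.shape[0]
    # Batch index of shape [B, 1, 1] broadcasts against ind of shape [B, T, K].
    batch_idx = torch.arange(B, device=emb.device).view(B, 1, 1)
    # Gathered neighbor rows: [B, T, K, C] (this is the large intermediate).
    neigh = emb[batch_idx, ind]
    # Dot product of each row with each of its K neighbors: [B, T, K].
    return torch.einsum("btc,btkc->btk", emb, neigh)


if __name__ == "__main__":
    emb = torch.randn(2, 6, 4)
    ind = torch.randint(0, 6, (2, 6, 3))
    print(neighborhood_dots(emb, ind).shape)
```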

Can I achieve this in PyTorch in a vectorized way without a Python loop over b and t and without materializing intermediate tensors?
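One partial workaround, sketched below under my own assumptions, is to process T in chunks so that the gathered intermediate is only [B, chunk_size, K, C] at any time. It still uses a Python loop (over chunks, not over individual b and t), so it is not the fused kernel I am asking about. The name neighborhood_dots_chunked and the chunk_size parameter are hypothetical.

```python
import torch


def neighborhood_dots_chunked(
    emb: torch.Tensor, ind: torch.Tensor, chunk_size: int = 128
) -> torch.Tensor:
    """Same result as the gather + einsum formulation, with bounded peak memory.

    emb: [B, T, C] float tensor
    ind: [B, T, K] int64 tensor with values in [0, T)
    """
    B, T, _ = emb.shape
    K = ind.shape[-1]
    out = emb.new_empty(B, T, K)
    batch_idx = torch.arange(B, device=emb.device).view(B, 1, 1)
    for start in range(0, T, chunk_size):
        stop = min(start + chunk_size, T)
        # Neighbor rows for this slice of queries only: [B, chunk, K, C].
        neigh = emb[batch_idx, ind[:, start:stop]]
        out[:, start:stop] = torch.einsum(
            "btc,btkc->btk", emb[:, start:stop], neigh
        )
    return out


if __name__ == "__main__":
    emb = torch.randn(2, 10, 4)
    ind = torch.randint(0, 10, (2, 10, 3))
    print(neighborhood_dots_chunked(emb, ind, chunk_size=4).shape)
```

The chunk size trades peak memory against the number of kernel launches; a smaller chunk lowers memory but adds loop overhead.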

This would enable a lot of experimentation with sparser attention patterns.

Maybe the question can also be formulated as: does PyTorch support a dense GEMM with a sparse output mask, so that the output is sparse and only computed where the mask is non-zero?
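To make the "sparse output mask" reading concrete, the sketch below computes the full dense product and then keeps only the entries selected by ind. It is meant only as a semantic reference for small inputs, since it materializes the whole [B, T, T] matrix that does not fit in memory in my real case. The function name is made up, and ind is assumed to be int64.

```python
import torch


def neighborhood_dots_via_full_product(
    emb: torch.Tensor, ind: torch.Tensor
) -> torch.Tensor:
    """Reference semantics: dense batched GEMM, then sample the output at ind.

    emb: [B, T, C] float tensor
    ind: [B, T, K] int64 tensor with values in [0, T)
    """
    # Full pairwise dot products: [B, T, T] (too large for real workloads).
    full = torch.bmm(emb, emb.transpose(1, 2))
    # out[b, t, k] = full[b, t, ind[b, t, k]]
    return full.gather(2, ind)


if __name__ == "__main__":
    emb = torch.randn(2, 6, 4)
    ind = torch.randint(0, 6, (2, 6, 3))
    print(neighborhood_dots_via_full_product(emb, ind).shape)
```

A masked or sampled GEMM would ideally return the same values without ever forming the full [B, T, T] product.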

Does torch.sparse.sampled_addmm (PyTorch 2.0 documentation) do this?

Hmm, it seems so, but it does not support batches. See "Implementation of torch.sparse.sampled_baddmm" (Issue #105319, pytorch/pytorch on GitHub).


