I would like to do an incomplete GEMM, where we compute the dot products only with a fixed number of other matrix rows (computing the full matrix would not fit in memory).
Given two tensors emb[B, T, C] and ind[B, T, K], where ind[b, t, :] contains the indices of the neighborhood of (b, t), I would like to compute

out[b, t, k] = \sum_c emb[b, t, c] * emb[b, ind[b, t, k], c]
Can I achieve this in PyTorch in a vectorized way, without a Python loop over b and t, and without materializing large intermediate tensors? This would enable a lot of experimentation with sparser attention patterns.
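For reference, here is a gather-based sketch of the computation I have so far (the function name is my own). It avoids Python loops, but it does materialize a [B, T, K, C] tensor of gathered neighbor rows, so it only partially meets the no-intermediates goal; it is practical when K is small:

```python
import torch

def neighborhood_dot(emb: torch.Tensor, ind: torch.Tensor) -> torch.Tensor:
    """out[b, t, k] = sum_c emb[b, t, c] * emb[b, ind[b, t, k], c]."""
    B, T, C = emb.shape
    K = ind.shape[-1]
    # Gather neighbor rows along dim 1: neigh[b, t, k, c] = emb[b, ind[b, t, k], c]
    idx = ind.reshape(B, T * K, 1).expand(-1, -1, C)  # [B, T*K, C] index tensor
    neigh = emb.gather(1, idx).reshape(B, T, K, C)    # [B, T, K, C] intermediate
    # Dot product over the channel dimension
    return torch.einsum('btc,btkc->btk', emb, neigh)

# Sanity check against an explicit loop over b, t, k
B, T, C, K = 2, 5, 4, 3
emb = torch.randn(B, T, C)
ind = torch.randint(0, T, (B, T, K))
out = neighborhood_dot(emb, ind)
ref = torch.stack([
    torch.stack([
        torch.stack([emb[b, t] @ emb[b, ind[b, t, k]] for k in range(K)])
        for t in range(T)])
    for b in range(B)])
assert torch.allclose(out, ref, atol=1e-5)
```

Eliminating the [B, T, K, C] intermediate entirely would, as far as I can tell, require something like a fused/custom kernel, which is part of what I am asking about.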