Incomplete/sparse GEMM (for some sort of local attention) computed only at given indices

Maybe the question can also be formulated as: does PyTorch support a dense GEMM with a sparse output mask, so that the output is sparse and is only computed where the mask is non-zero?

Is torch.sparse.sampled_addmm (PyTorch 2.0 documentation) doing this?

Hmm, it seems so, but it does not support batches :( See: Implementation of torch.sparse.sampled_baddmm · Issue #105319 · pytorch/pytorch · GitHub
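
For reference, here is a minimal sketch of what I mean by a batched "sampled" GEMM, written with plain indexing rather than any sparse kernel. The helper name `sampled_bmm` and the index pattern are made up for illustration; this is not how torch.sparse.sampled_addmm is implemented, and it materializes two (B, nnz, K) intermediates, so it only saves work when the number of requested positions is much smaller than M * N.

```python
import torch


def sampled_bmm(a, b, rows, cols):
    """Compute (a @ b)[:, rows, cols] without forming the full product.

    a: (B, M, K) dense, b: (B, K, N) dense.
    rows, cols: (nnz,) long tensors giving the output positions to compute.
    Returns a (B, nnz) tensor with the sampled output values.
    """
    a_sel = a[:, rows, :]                  # (B, nnz, K): needed rows of a
    b_sel = b[:, :, cols].transpose(1, 2)  # (B, nnz, K): needed columns of b
    return (a_sel * b_sel).sum(dim=-1)     # one dot product per position


torch.manual_seed(0)
a = torch.randn(2, 4, 3)
b = torch.randn(2, 3, 5)

# Hypothetical local-attention pattern: a handful of (row, col) positions.
rows = torch.tensor([0, 0, 1, 2, 3])
cols = torch.tensor([0, 1, 1, 2, 4])

vals = sampled_bmm(a, b, rows, cols)

# Dense reference: full batched product, then pick the same positions.
reference = torch.bmm(a, b)[:, rows, cols]
```

The values in `vals` line up with `(rows, cols)`, so they could be packed into a sparse tensor afterwards if a sparse output is needed.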

@albanD
