Sparse Tensor implementation (sparsify last dimension only)

A. batched sparse tensors

The problem could be solved with a batched sparse tensor implementation, i.e. a hybrid tensor whose first dimension is dense and whose remaining dimension(s) are sparse:

  1. To multiply a dense tensor (N-1 dims) with the batched sparse representation (2 dims: 1 dense, 1 sparse) of the underlying hybrid tensor (N dims: N-1 dense, 1 sparse), the dense tensor could be temporarily reshaped to the sparse batch size (1 dim).
  2. Once the operations on the sparse dimension have been performed and the sparse dimension has been removed, the result could be reshaped from the batched representation back to the original dense shape (N-1 dims).
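
The two reshaping steps can be illustrated with a minimal, library-free sketch. It is not PyTorch code: each sparse row is modelled as a plain `{index: value}` dict, and the helper names `flatten` and `unflatten`, the example shapes, and the choice of operation (scale each row by the matching dense entry, then sum over the sparse dimension) are assumptions made purely for illustration.

```python
from math import prod


def flatten(nested, ndim):
    """Collapse the first `ndim` nesting levels of a nested list into one list."""
    flat = nested
    for _ in range(ndim - 1):
        flat = [item for sub in flat for item in sub]
    return flat


def unflatten(flat, shape):
    """Inverse of flatten: rebuild nested lists with the given dense shape."""
    if len(shape) == 1:
        return list(flat)
    step = prod(shape[1:])
    return [unflatten(flat[i * step:(i + 1) * step], shape[1:])
            for i in range(shape[0])]


# Hybrid tensor with N = 3 dims: dense shape (2, 3), sparse last dimension.
dense_shape = (2, 3)
hybrid = [
    [{0: 1.0, 4: 2.0}, {}, {2: 3.0}],
    [{1: 4.0}, {0: 5.0, 3: 6.0}, {}],
]
# Dense tensor with N-1 = 2 dims, same shape as the dense part of `hybrid`.
scale = [[1.0, 2.0, 3.0], [0.5, 2.0, 4.0]]

# Step 1: flatten both operands to the sparse batch size (6 entries each).
rows = flatten(hybrid, len(dense_shape))
weights = flatten(scale, len(dense_shape))

# Operate along the sparse dimension; the sum removes that dimension.
reduced = [w * sum(row.values()) for w, row in zip(weights, rows)]

# Step 2: reshape the batched result back to the original dense shape.
result = unflatten(reduced, dense_shape)
print(result)
```

Because the dense dimensions are only flattened and restored, any per-row sparse operation could be substituted for the sum without changing the surrounding reshape logic.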

It appears that batched sparse tensors have never been implemented in PyTorch.

B. hybrid dense-sparse tensor implementations

A possible implementation of a hybrid dense-sparse tensor would be a dense tensor whose elements are integer pointers to sparse tensors (each sparse tensor stores its elements as value/index pairs, as in the current PyTorch sparse tensor implementation).

When the user requests a particular subset of the tensor (i.e. any operation that requires indexing the dense dimensions/batch samples of the tensor, e.g. t[3] or torch.min(t, dim=1)), the implementation would effectively perform a gather (single- or multi-index lookup) on the pointers to collect the relevant sparse tensors (elements).
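
As a rough sketch of this pointer-plus-gather idea (again library-free, not an existing PyTorch class), the hypothetical `HybridTensor` below keeps a dense list of integer pointers into a separate store of sparse rows. The class name, its attributes, and the `{index: value}` row format are all assumptions made for illustration, and only a single dense dimension is modelled.

```python
class HybridTensor:
    """Hypothetical hybrid tensor: one dense batch dimension, sparse last dimension."""

    def __init__(self, pointers, rows, sparse_size):
        self.pointers = pointers        # dense part: one integer pointer per sample
        self.rows = rows                # sparse store: list of {index: value} dicts
        self.sparse_size = sparse_size  # length of the sparse dimension

    def __getitem__(self, key):
        # Single index: follow one pointer and return that sparse row.
        if isinstance(key, int):
            return self.rows[self.pointers[key]]
        # Multi-index: gather the pointers only; the sparse store is shared.
        return HybridTensor([self.pointers[k] for k in key],
                            self.rows, self.sparse_size)

    def to_dense(self):
        # Materialise every referenced sparse row as a list of floats.
        out = []
        for p in self.pointers:
            dense_row = [0.0] * self.sparse_size
            for idx, val in self.rows[p].items():
                dense_row[idx] = val
            out.append(dense_row)
        return out


rows = [{0: 1.0, 3: 2.0}, {2: 5.0}, {}]
t = HybridTensor([0, 1, 2, 1], rows, sparse_size=4)

single = t[3]          # single-index lookup
subset = t[[3, 0]]     # multi-index gather
print(single, subset.to_dense())
```

In this sketch a gather copies only integer pointers, so indexing the dense dimension never touches the stored sparse values.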

Are there specific developers working on (hybrid) sparse tensor implementations?