I would like to be able to sparsify a single dimension (the last one) of an N-dimensional dense tensor, because I need to manipulate the resulting sparse tensor with ordinary PyTorch operations. One example of an operation I need to perform on the tensor is normaliseTensor, which currently fails because min/max are unavailable for sparse tensors.
min(dim)/max(dim) could easily be implemented with a hybrid dense/sparse tensor in which only the last dimension is sparse.
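A minimal pure-Python sketch of why min(dim)/max(dim) along a sparsified last dimension is straightforward. It assumes a hypothetical per-row layout (sorted indices plus their non-zero values, with every absent index an implicit zero). The names SparseRow, row_min_max and min_max_last_dim are illustrative only and are not PyTorch APIs. The one subtlety is that the implicit zeros must take part in the reduction whenever a row is not fully populated.

```python
from typing import List, Tuple

# Assumed layout (not a PyTorch type): one sparse row stores sorted indices
# and the corresponding non-zero values; every absent index is an implicit 0.
SparseRow = Tuple[List[int], List[float]]


def row_min_max(row: SparseRow, size: int) -> Tuple[float, float]:
    """Min and max of one sparse row of logical length `size`, counting implicit zeros."""
    indices, values = row
    if not values:
        return 0.0, 0.0  # the row is entirely implicit zeros
    lo, hi = min(values), max(values)
    if len(indices) < size:
        # at least one position is an implicit zero, so 0 takes part in the reduction
        lo, hi = min(lo, 0.0), max(hi, 0.0)
    return lo, hi


def min_max_last_dim(rows: List[SparseRow], size: int) -> Tuple[List[float], List[float]]:
    """Reduce the sparse last dimension of every row; the result is dense (one value per row)."""
    pairs = [row_min_max(row, size) for row in rows]
    return [lo for lo, _ in pairs], [hi for _, hi in pairs]


# Three rows (the flattened dense dimensions), sparse dimension of size 4
rows = [([0, 3], [2.0, -1.0]), ([], []), ([0, 1, 2, 3], [5.0, 6.0, 7.0, 8.0])]
mins, maxs = min_max_last_dim(rows, size=4)
print(mins)  # [-1.0, 0.0, 5.0]
print(maxs)  # [2.0, 0.0, 8.0]
```

Note that a min-max normalisation (x - min) / (max - min) maps every implicit zero to -min / (max - min), so a normalised row stays sparse only when its minimum is 0.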
Ideally, the user would be able to select the dimension along which a tensor is sparsified (e.g. the last dimension of an N-dimensional dense tensor), and all sparse operations and optimisations would work under the hood, hardware permitting, with no change to the PyTorch interface (the tensors would look like ordinary dense tensors).
Is there perhaps a way to manipulate the existing PyTorch sparse tensor types to emulate this functionality?
A. batched sparse tensor implementations
The problem could be solved with a batched sparse tensor implementation (a hybrid tensor whose first dimension is dense and whose remaining dimension(s) are sparse):
To multiply a dense tensor (N-1 dims) with the batched sparse representation (2 dims: 1 dense, 1 sparse) of the underlying hybrid tensor (N dims: N-1 dense, 1 sparse), the dense tensor could be temporarily reshaped to the sparse batch size (1 dim).
After the operations along the sparse dimension have been performed and that dimension has been removed (e.g. by a reduction such as min/max), the result, which is now only the batch dimension (1 dim) of the batched representation, could be reshaped back to the original dense shape (N-1 dims).
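The reshape round trip described above can be sketched in pure Python. This is a minimal illustration under stated assumptions: the batched representation is a hypothetical list of prod(dense_shape) sparse rows stored in row-major order of the dense dimensions, the dense N-1-dim tensor is a nested list, and summation stands in for "the sparse dimension operations". The helpers flatten, unflatten and multiply_then_sum are illustrative names, not PyTorch functions.

```python
from math import prod
from typing import List, Sequence, Tuple

# Assumed layout (not a PyTorch type): (sorted indices, non-zero values) per row
SparseRow = Tuple[List[int], List[float]]


def flatten(nested, shape: Sequence[int]) -> List[float]:
    """Row-major flatten of a nested list with the given shape (N-1 dims -> 1 dim)."""
    if len(shape) == 0:
        return [nested]
    out: List[float] = []
    for item in nested:
        out.extend(flatten(item, shape[1:]))
    return out


def unflatten(flat: List[float], shape: Sequence[int]):
    """Inverse of `flatten`: rebuild the nested N-1-dim structure from the batch dim."""
    if len(shape) == 0:
        return flat[0]
    step = prod(shape[1:])
    return [unflatten(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def multiply_then_sum(rows: List[SparseRow], dense, dense_shape: Sequence[int]):
    """Multiply each sparse row by the matching dense element, then sum out the sparse dim.

    `rows` is the batched representation: prod(dense_shape) sparse rows stored in
    row-major order of the dense dimensions.
    """
    scales = flatten(dense, dense_shape)  # temporary reshape to the batch size (1 dim)
    if len(scales) != len(rows):
        raise ValueError("dense tensor does not match the batch size")
    # broadcast multiply: only stored values change, implicit zeros stay zero
    scaled = [(indices, [v * s for v in values]) for (indices, values), s in zip(rows, scales)]
    reduced = [float(sum(values)) for _, values in scaled]  # sparse dimension removed
    return unflatten(reduced, dense_shape)  # back to the original N-1 dense dims


# dense_shape (2, 2) -> batch of 4 sparse rows (sparse dimension of size 5)
rows = [([1], [1.0]), ([0, 4], [2.0, 3.0]), ([], []), ([2], [4.0])]
dense = [[10.0, 2.0], [7.0, 0.5]]
result = multiply_then_sum(rows, dense, (2, 2))
print(result)  # [[10.0, 10.0], [0.0, 2.0]]
```

Because the multiplication only touches stored values, the sparsity pattern of each row is preserved until the reduction removes the sparse dimension.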
However, it appears that batched sparse tensors have never been implemented in PyTorch.
B. hybrid dense-sparse tensor implementations
One possible implementation of a hybrid dense-sparse tensor would be a dense tensor in which each element is an integer pointer to a sparse tensor (each sparse tensor stores its elements as value/index pairs, as in the current PyTorch sparse tensor implementation).
When the user requests a particular subset of the tensor (i.e. any operation that indexes the dense dimensions/batch samples of the tensor, e.g. t[3] or torch.min(t, dim=1)), the implementation would perform a gather (a single- or multi-index lookup) on the pointers to retrieve the relevant sparse tensors (elements).
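The pointer-based design can be sketched in pure Python, assuming a hypothetical layout: a nested list of integer pointers (one per dense position) into a shared pool of sparse rows, with only the last dimension sparse. PointerHybridTensor, sparse_min and min_dense_dim are illustrative names, not PyTorch APIs. Indexing a dense dimension is just a pointer gather, while a min over a dense dimension gathers the relevant rows and merges them elementwise. Only the values of torch.min(t, dim=...) are computed here, not the argmin indices.

```python
from typing import List, Optional, Tuple, Union

# Assumed layout (not a PyTorch type): (sorted indices, non-zero values) per sparse row
SparseRow = Tuple[List[int], List[float]]
Pointers = Union[int, list]  # nested lists of ints, one pointer per dense position


class PointerHybridTensor:
    """Hypothetical hybrid tensor of shape (*dense_shape, size): a dense grid of integer
    pointers into a shared pool of sparse rows; only the last dimension is sparse."""

    def __init__(self, pointers: Pointers, pool: List[SparseRow], size: int):
        self.pointers = pointers
        self.pool = pool  # shared storage; indexing never copies sparse rows
        self.size = size  # logical length of the sparse dimension

    def __getitem__(self, i: int) -> "PointerHybridTensor":
        # t[i] on a dense dimension is just a pointer lookup (a gather)
        return PointerHybridTensor(self.pointers[i], self.pool, self.size)

    def to_dense(self, ptrs: Optional[Pointers] = None):
        ptrs = self.pointers if ptrs is None else ptrs
        if isinstance(ptrs, int):
            indices, values = self.pool[ptrs]
            row = [0.0] * self.size
            for j, v in zip(indices, values):
                row[j] = v
            return row
        return [self.to_dense(p) for p in ptrs]


def sparse_min(a: SparseRow, b: SparseRow) -> SparseRow:
    """Elementwise min of two sparse rows; an index missing from a row is an implicit 0."""
    da, db = dict(zip(*a)), dict(zip(*b))
    out_idx, out_val = [], []
    for j in sorted(set(da) | set(db)):
        v = min(da.get(j, 0.0), db.get(j, 0.0))
        if v != 0.0:  # keep the result sparse
            out_idx.append(j)
            out_val.append(v)
    return out_idx, out_val


def min_dense_dim(t: PointerHybridTensor, dim: int) -> PointerHybridTensor:
    """Values of torch.min(t, dim=dim) for a dense `dim` (argmin indices are not computed)."""
    pool = list(t.pool)  # new rows are appended to a copy; the input is untouched

    def combine(p: Pointers, q: Pointers) -> Pointers:
        if isinstance(p, int):
            pool.append(sparse_min(pool[p], pool[q]))
            return len(pool) - 1
        return [combine(x, y) for x, y in zip(p, q)]

    def reduce_at(ptrs: Pointers, d: int) -> Pointers:
        if d == 0:  # gather the sub-grids along this dimension and fold them together
            acc = ptrs[0]
            for sub in ptrs[1:]:
                acc = combine(acc, sub)
            return acc
        return [reduce_at(sub, d - 1) for sub in ptrs]

    return PointerHybridTensor(reduce_at(t.pointers, dim), pool, t.size)


pool = [([0], [1.0]), ([1], [-2.0]), ([], []), ([0, 2], [3.0, -1.0])]
t = PointerHybridTensor([[0, 1], [2, 3]], pool, size=3)  # shape (2, 2, 3)
print(t[1].to_dense())  # [[0.0, 0.0, 0.0], [3.0, 0.0, -1.0]]
print(min_dense_dim(t, dim=1).to_dense())  # [[0.0, -2.0, 0.0], [0.0, 0.0, -1.0]]
```

In this sketch, t[i]-style indexing only copies pointers, so it is cheap; a reduction over a dense dimension, by contrast, has to merge sparse rows and allocate new ones.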
Are there specific developers working on (hybrid) sparse tensor implementations?