Sparse Tensor implementation (sparsify last dimension only)

I would like to be able to sparsify a single (last) dimension of an N-dimensional dense tensor, because I need to manipulate the resulting sparse tensor with ordinary PyTorch operations. One operation I need to perform on it is normaliseTensor (torch.min/torch.max along a dimension are unavailable for sparse tensors):

```python
import torch

def normaliseTensor(tensor):
    # Min-max normalise along the last dimension.
    min_vals, _ = torch.min(tensor, dim=-1, keepdim=True)
    max_vals, _ = torch.max(tensor, dim=-1, keepdim=True)
    epsilon = 1e-8  # small value to avoid division by zero
    tensor = (tensor - min_vals) / (max_vals - min_vals + epsilon)
    return tensor
```

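torch.min(dim)/torch.max(dim) could be implemented easily with a hybrid dense/sparse tensor in which only the last dimension is sparsified, since each reduction then only has to visit the stored elements of one row (plus an implicit zero if the row is not full).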
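A minimal sketch of such a per-row min/max with today's COO type, assuming a floating-point tensor whose leading dimensions are flattened into the rows of a 2-D COO copy. `sparse_last_dim_min_max` is a hypothetical name; it only reads the stored values and adds back the implicit zero for rows that are not full.

```python
import torch


def sparse_last_dim_min_max(t: torch.Tensor):
    """Hypothetical sketch: min/max over the last dim (keepdim=True), computed
    only from the stored values of a 2-D COO copy. Assumes a float dtype."""
    *batch, k = t.shape
    # Batched representation: flatten the leading (dense) dims into rows.
    s = t.reshape(-1, k).to_sparse().coalesce()
    rows = s.indices()[0]          # row of each stored (nonzero) value
    vals = s.values()
    n_rows = s.shape[0]

    # A row with fewer than k stored values contains at least one implicit zero.
    has_zero = torch.bincount(rows, minlength=n_rows) < k

    # Reduce the stored values of each row; empty rows keep the +/-inf start.
    mins = torch.full((n_rows,), float("inf"), dtype=vals.dtype, device=vals.device)
    mins = mins.scatter_reduce(0, rows, vals, reduce="amin")
    maxs = torch.full((n_rows,), float("-inf"), dtype=vals.dtype, device=vals.device)
    maxs = maxs.scatter_reduce(0, rows, vals, reduce="amax")

    # Fold the implicit zeros back in.
    zero = torch.zeros((), dtype=vals.dtype, device=vals.device)
    mins = torch.where(has_zero, torch.minimum(mins, zero), mins)
    maxs = torch.where(has_zero, torch.maximum(maxs, zero), maxs)

    # Restore the dense leading shape, mirroring keepdim=True.
    return mins.reshape(*batch, 1), maxs.reshape(*batch, 1)
```

These could stand in for the torch.min/torch.max calls in normaliseTensor. Note that (tensor - min_vals) maps every implicit zero to -min_vals, so the normalised result only stays sparse in rows whose minimum is 0.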

Ideally, the user should be able to select the dimension along which a tensor is sparsified (e.g. the last dimension of an N-dimensional dense tensor), and all the sparse matrix operations/optimisations should then work under the hood with no change to the PyTorch interface (the tensors should look like dense tensors in PyTorch), hardware permitting.

Perhaps the existing PyTorch sparse tensor types could be manipulated to emulate this functionality?
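For comparison, PyTorch's existing hybrid COO tensors put the sparse dimensions first and the dense dimensions last (`Tensor.sparse_dim()` leading dims, `Tensor.dense_dim()` trailing dims), the reverse of the layout asked for here. One possible emulation is sketched below, assuming it is acceptable to flatten every dimension except the chosen one into the rows of a 2-D CSR tensor: CSR stores each row's nonzeros contiguously and `crow_indices()` marks the row boundaries, so each row plays the role of one batch sample. `to_row_sparse` and `from_row_sparse` are hypothetical helper names, not PyTorch APIs.

```python
import torch


def to_row_sparse(t: torch.Tensor, dim: int = -1):
    """Hypothetical helper: sparsify t along `dim` by storing it as a 2-D CSR
    tensor with one row per combination of the remaining (dense) indices."""
    moved = t.movedim(dim, -1)          # put the chosen dim last
    batch_shape = moved.shape[:-1]      # the dense dims, kept for later
    csr = moved.reshape(-1, moved.shape[-1]).to_sparse_csr()
    return csr, batch_shape


def from_row_sparse(csr: torch.Tensor, batch_shape: torch.Size, dim: int = -1):
    """Inverse of to_row_sparse (densifies; for checking round trips only)."""
    dense = csr.to_dense().reshape(*batch_shape, csr.shape[-1])
    return dense.movedim(-1, dim)


t = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4) % 3   # many zeros
csr, batch_shape = to_row_sparse(t, dim=1)   # sparsify the middle dim
print(csr.shape)                             # 8 rows (2*4), 3 columns
print(csr.crow_indices().diff())             # number of stored values per row
```

Element-wise operations and per-row reductions can then work on `csr.values()` and the row boundaries directly; anything else would still have to be written (or densified) by hand.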

A. batched sparse tensors

The problem could be solved with a batched sparse tensor implementation (a hybrid tensor where the first dimension is dense and the other(s) are sparse):

  1. To multiply a dense tensor (N-1 dims) with the batched sparse representation (2 dims: 1 dense, 1 sparse) of the underlying hybrid tensor (N dims: N-1 dense, 1 sparse), the dense tensor could be temporarily reshaped to the sparse batch size (1 dim).
  2. After the sparse-dimension operations have been performed and the sparse dimension has been removed (e.g. by a reduction), the result could be reshaped from the batch dimension of the batched sparse representation back to the original dense shape (N-1 dims).
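The two steps above could be sketched with the existing COO type, assuming the batched representation is a 2-D COO tensor of shape (B, K), where B is the product of the N-1 dense sizes. `scale_and_reduce_last_dim` is a hypothetical name, and the per-row sum stands in for whatever sparse-dimension operation is actually needed.

```python
import torch


def scale_and_reduce_last_dim(t: torch.Tensor, d: torch.Tensor):
    """Hypothetical sketch: t is (*batch, K) and sparse along K; d is dense with
    shape (*batch). Returns t * d[..., None] as a batched (B, K) sparse tensor,
    plus its sum over K reshaped back to (*batch)."""
    *batch, k = t.shape
    s = t.reshape(-1, k).to_sparse().coalesce()   # batched representation (B, K)
    idx = s.indices()
    rows = idx[0]                                 # batch row of each stored value

    # Step 1: reshape the (N-1)-dim dense tensor to the batch size B (1 dim)
    # and multiply each stored value by the entry for its row.
    scaled_vals = s.values() * d.reshape(-1)[rows]
    scaled = torch.sparse_coo_tensor(idx, scaled_vals, s.shape)

    # Remove the sparse dim: per-row sum of stored values (implicit zeros add 0).
    summed = torch.zeros(s.shape[0], dtype=scaled_vals.dtype, device=scaled_vals.device)
    summed.index_add_(0, rows, scaled_vals)

    # Step 2: reshape the 1-dim batch result back to the dense shape (N-1 dims).
    return scaled, summed.reshape(batch)
```

Looking up the dense factor through the row ids (`d.reshape(-1)[rows]`) keeps step 1 proportional to the number of stored elements rather than to B × K.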

It appears that batched sparse tensors have never been implemented in PyTorch.

B. hybrid dense-sparse tensor implementations

One possible implementation of a hybrid dense-sparse tensor would be a dense tensor of integer pointers, each pointing to a sparse tensor (each sparse tensor holds value/index pairs, as in the current PyTorch sparse tensor implementation).

When the user requests a subset of the tensor (i.e. any operation that indexes the dense dimensions/batch samples, e.g. t[3] or torch.min(t, dim=1)), the implementation would perform a gather (single- or multi-index lookup) on the pointers to collect the relevant sparse tensors (elements).
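A toy version of this pointer-based layout is sketched below, assuming a Python list as the pool of 1-D sparse COO tensors and an int64 tensor of pool indices as the dense part. `PointerHybrid` and its methods are hypothetical; indexing slices only the pointer tensor, and the reduction gathers the referenced sparse rows. It loops in Python, so it illustrates the data structure rather than an efficient kernel, and it only reduces over the last (sparse) dimension.

```python
import torch


class PointerHybrid:
    """Hypothetical sketch: a dense tensor of integer pointers, each referring
    to a 1-D sparse COO tensor that holds the (sparse) last dimension."""

    def __init__(self, dense: torch.Tensor):
        *batch, k = dense.shape
        # Pool of sparse rows (value/index pairs), one per batch element.
        self.pool = [row.to_sparse().coalesce() for row in dense.reshape(-1, k)]
        # Dense part: an integer pointer into the pool for every batch position.
        self.ptr = torch.arange(len(self.pool)).reshape(batch)
        self.k = k

    def __getitem__(self, idx):
        # Indexing the dense dims only slices the pointers; the pool is shared.
        sub = object.__new__(PointerHybrid)
        sub.pool, sub.ptr, sub.k = self.pool, self.ptr[idx], self.k
        return sub

    def gather(self):
        # Single- or multi-index lookup of the referenced sparse rows.
        return [self.pool[i] for i in self.ptr.reshape(-1).tolist()]

    def amin_last(self):
        # Equivalent of torch.amin(t, dim=-1), computed from the gathered rows.
        out = []
        for s in self.gather():
            v = s.values()
            zero = v.new_zeros(())
            if v.numel() == 0:
                m = zero                    # row is entirely implicit zeros
            else:
                m = v.min()
                if v.numel() < self.k:      # row also contains implicit zeros
                    m = torch.minimum(m, zero)
            out.append(m)
        return torch.stack(out).reshape(self.ptr.shape)
```

Here `h[1]` plays the role of t[1]: it returns a new PointerHybrid whose pointers are `h.ptr[1]`, without copying any sparse data.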

Are there specific developers working on (hybrid) sparse tensor implementations?