Slice a sparse CSR tensor

I am storing a dataset as a sparse CSR tensor and I want to write a data loader which selects certain row indices and outputs a dense tensor.

Something like this:

import torch

class Dataset:
    def __init__(self, dataset: torch.Tensor, batch_size: int, shuffle: bool):
        self.dataset = dataset  # sparse tensor, one sample per row
        self.batch_size = batch_size
        self.inds = torch.arange(dataset.shape[0])  # row order used for batching
        self.ptr = 0
        self.shuffle = shuffle

    def _reset(self):
        # Reshuffle the row order (if requested) and rewind to the first batch.
        if self.shuffle:
            self.inds = self.inds[torch.randperm(len(self.inds))]
        self.ptr = 0

    def __iter__(self):
        # Start every pass over the data from a fresh (possibly shuffled) order.
        self._reset()
        return self

    def __next__(self):
        if self.ptr == len(self.inds):
            raise StopIteration()
        next_ptr = min(len(self.inds), self.ptr + self.batch_size)

        inds = self.inds[self.ptr:next_ptr]
        # Select the batch rows from the sparse tensor, then densify them.
        dense_tensor = self.dataset.index_select(0, inds).to_dense()
        self.ptr = next_ptr
        return (dense_tensor, inds)

If dataset is a COO sparse tensor, everything works fine, but if it is a CSR sparse tensor, I get the following error:

NotImplementedError: Could not run 'aten::index_select' with arguments from the 'SparseCsrCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit for possible resolutions. 'aten::index_select' is only available for these backends: [CPU, SparseCPU, BackendSelect, Named, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, UNKNOWN_TENSOR_TYPE_ID, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

Is index_select really not implemented for sparse CSR tensors? Isn’t that what CSR format is fast at? What’s the most efficient way for me to slice a sparse torch tensor… something that will work on CPU and GPU?

This is related to an earlier thread (Sparse tensor support for slice, reduce sum and element wise comparison?), but that one has been quiet for a while.


I’ve compared the following two strategies:

  1. keep data as a scipy.sparse.csr_matrix, do the slicing, make it dense, convert to torch, and then move to GPU
  2. keep data as a torch.sparse_coo on GPU, do the slicing, make it dense
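
For concreteness, strategy 1 can be sketched roughly as below. This is a minimal illustration and not the exact code I timed: the helper name dense_batch and the small random test matrix are made up for the example, and it falls back to CPU when no GPU is available.

```python
import numpy as np
import scipy.sparse as sp
import torch

def dense_batch(csr: sp.csr_matrix, inds: np.ndarray, device: torch.device) -> torch.Tensor:
    """Slice rows of a SciPy CSR matrix on CPU, densify, and move to `device`."""
    block = csr[inds].toarray()                 # row slicing + densify, both on CPU
    return torch.from_numpy(block).to(device)   # host-to-device copy happens here

# Hypothetical small test matrix (the real data is far larger).
csr = sp.random(20, 30, density=0.1, format="csr", dtype=np.float32, random_state=0)
inds = np.array([5, 0, 17])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = dense_batch(csr, inds, device)
```

Only the dense minibatch crosses to the GPU here, so the per-batch transfer is batch_size by n_columns values regardless of sparsity.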

Unless I’m doing something wrong, here is what I see on a test sparse dataset (meant to resemble single-cell RNA sequencing data: 10k rows by 30k columns, with 900k nonzero elements) when timing the loop below under the two strategies:

for d in dataloader:
  1. takes 60ms, with data on CPU being sliced and moved to GPU for each minibatch
  2. takes 600ms, with everything on GPU

I’d really love to be able to slice a torch.sparse CSR tensor on GPU as fast as scipy.sparse can slice a CSR matrix on CPU.
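
One possible workaround, sketched below under some assumptions, is to skip index_select and gather the rows by hand from the three CSR arrays, which are ordinary dense tensors and so can be indexed on either CPU or GPU. The function name csr_rows_to_dense is made up for this example, it assumes the default int64 index dtype and no duplicate entries within a row, and I have not benchmarked it against scipy.

```python
import torch

def csr_rows_to_dense(crow, col, val, rows, n_cols):
    """Gather the given rows of a CSR matrix into a dense (len(rows), n_cols) tensor."""
    dev = val.device
    starts = crow[rows]                  # offset of the first nonzero of each selected row
    counts = crow[rows + 1] - starts     # number of nonzeros in each selected row
    out = torch.zeros(len(rows), n_cols, dtype=val.dtype, device=dev)
    # Output row id for every nonzero that will be gathered.
    out_row = torch.repeat_interleave(torch.arange(len(rows), device=dev), counts)
    if out_row.numel() == 0:
        return out
    # Position of each gathered nonzero inside its own row.
    seg_start = torch.cumsum(counts, 0) - counts
    within = torch.arange(out_row.numel(), device=dev) - seg_start[out_row]
    src = starts[out_row] + within       # index into col / val for each gathered nonzero
    out[out_row, col[src]] = val[src]
    return out

# Small usage example; to_sparse_csr() assumes a reasonably recent PyTorch.
dense = torch.tensor([[0., 1., 0.],
                      [2., 0., 3.],
                      [0., 0., 0.],
                      [4., 0., 0.]])
csr = dense.to_sparse_csr()
rows = torch.tensor([3, 1, 2])
got = csr_rows_to_dense(csr.crow_indices(), csr.col_indices(), csr.values(),
                        rows, dense.shape[1])
```

Note that repeat_interleave with a tensor of counts has to know the output size, so on GPU this likely forces a synchronization per batch; whether it still beats the COO path would need measuring.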

