Slice a sparse CSR tensor

I am storing a dataset as a sparse CSR tensor and I want to write a data loader which selects certain row indices and outputs a dense tensor.

Something like this:

```python
import torch


class Dataset:
    def __init__(self, dataset: torch.Tensor, batch_size: int, shuffle: bool):
        self.dataset = dataset  # sparse tensor (COO or CSR), one sample per row
        self.batch_size = batch_size
        self.inds = torch.arange(dataset.shape[0], device=dataset.device)  # int64
        self.ptr = 0
        self.shuffle = shuffle

    def _reset(self):
        if self.shuffle:
            perm = torch.randperm(len(self.inds), device=self.inds.device)
            self.inds = self.inds[perm]
        self.ptr = 0

    def __iter__(self):
        # start a fresh (optionally reshuffled) epoch every time iteration begins
        self._reset()
        return self

    def __next__(self):
        if self.ptr == len(self.inds):
            raise StopIteration()
        next_ptr = min(len(self.inds), self.ptr + self.batch_size)

        inds = self.inds[self.ptr:next_ptr]
        # works for COO; raises NotImplementedError for CSR (see the error below)
        dense_tensor = self.dataset.index_select(0, inds).to_dense()
        self.ptr = next_ptr
        return (dense_tensor, inds)
```

If dataset is a COO sparse tensor, everything works fine, but if it is a CSR sparse tensor, I get the following error:

NotImplementedError: Could not run 'aten::index_select' with arguments from the 'SparseCsrCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit for possible resolutions. 'aten::index_select' is only available for these backends: [CPU, SparseCPU, BackendSelect, Named, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, UNKNOWN_TENSOR_TYPE_ID, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

Is index_select really not implemented for sparse CSR tensors? Isn’t row selection exactly what the CSR format is fast at? What’s the most efficient way to slice a sparse torch tensor, ideally in a way that works on both CPU and GPU?

This is related to an older thread (Sparse tensor support for slice, reduce sum and element wise comparison?), but that was a while ago.
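One possible workaround, sketched here under the assumption that your PyTorch version provides `to_sparse_csr()`, `crow_indices()`, `col_indices()` and `values()` on CSR tensors: do the row gather by hand with vectorized tensor ops, which run on whatever device the CSR tensor lives on. The function name `csr_rows_to_dense` is hypothetical (not a PyTorch API), and it assumes each row has no duplicate column entries, which holds for tensors produced by `to_sparse_csr()`.

```python
import torch


def csr_rows_to_dense(csr: torch.Tensor, rows: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: gather `rows` of a 2-D sparse CSR tensor into a
    dense (len(rows), ncols) tensor using only vectorized tensor ops."""
    crow = csr.crow_indices().long()
    col = csr.col_indices().long()
    val = csr.values()
    rows = rows.to(device=crow.device, dtype=torch.long)

    starts = crow[rows]                          # offset of each selected row in col/val
    counts = crow[rows + 1] - starts             # nonzeros per selected row
    offsets = torch.cumsum(counts, 0) - counts   # exclusive prefix sum of counts
    total = int(counts.sum())                    # total gathered nonzeros (syncs on GPU)

    # output row of every gathered nonzero, e.g. counts [1, 2] -> [0, 1, 1]
    out_row = torch.repeat_interleave(
        torch.arange(len(rows), device=crow.device), counts)
    # source position in col/val: row start + rank of the nonzero within its row
    src = torch.repeat_interleave(starts - offsets, counts) + torch.arange(
        total, device=crow.device)

    out = torch.zeros(len(rows), csr.shape[1], dtype=val.dtype, device=val.device)
    out[out_row, col[src]] = val[src]
    return out


dense = torch.tensor([[0., 1., 0.], [2., 0., 3.], [0., 0., 0.], [0., 0., 4.]])
csr = dense.to_sparse_csr()
batch = csr_rows_to_dense(csr, torch.tensor([3, 0, 2, 1, 3]))
```

Here `batch` equals `dense[[3, 0, 2, 1, 3]]`, including the all-zero row and the repeated index. The only host/device synchronization is the `int(counts.sum())` call that sizes the gather.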



I’ve compared the following two strategies:

  1. keep data as a scipy.sparse.csr_matrix, do the slicing, make it dense, convert to torch, and then move to GPU
  2. keep data as a torch.sparse_coo on GPU, do the slicing, make it dense
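Strategy 1 boils down to a few lines. This is a minimal sketch assuming the dataset is a `scipy.sparse.csr_matrix`; the helper name `scipy_rows_to_torch` and the toy matrix sizes are illustrative only.

```python
import numpy as np
import scipy.sparse as sp
import torch


def scipy_rows_to_torch(X: sp.csr_matrix, rows: np.ndarray,
                        device: str = "cpu") -> torch.Tensor:
    """Hypothetical helper for strategy 1: slice rows in scipy, densify on
    the CPU, wrap the numpy array without copying, then move it to `device`."""
    batch = X[rows].toarray()                 # CSR row slicing -> dense numpy array
    return torch.from_numpy(batch).to(device)


X = sp.random(1000, 300, density=0.01, format="csr", dtype=np.float32, random_state=0)
rows = np.random.default_rng(0).integers(0, X.shape[0], size=64)
batch = scipy_rows_to_torch(X, rows)          # pass device="cuda" to move it to the GPU
print(batch.shape)                            # torch.Size([64, 300])
```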

Unless I’m doing something wrong, on a test sparse dataset meant to resemble single-cell RNA sequencing data (10k rows by 30k columns, with 900k nonzero elements), the loop for d in dataloader takes:

  1. 60 ms for strategy 1, with the data sliced on the CPU and each minibatch moved to the GPU
  2. 600 ms for strategy 2, with everything on the GPU
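When comparing numbers like these, note that CUDA kernels are launched asynchronously, so a wall-clock timing of a GPU loop is only reliable if the device is synchronized before the clock is read. This alone may not explain the gap above; it just rules out one common measurement pitfall. A minimal sketch of such a timer (the name `time_loop` is hypothetical):

```python
import time
from typing import Iterable, Optional

import torch


def time_loop(dataloader: Iterable, sync_cuda: Optional[bool] = None) -> float:
    """Wall-clock seconds for one pass of `for d in dataloader`."""
    if sync_cuda is None:
        sync_cuda = torch.cuda.is_available()
    if sync_cuda:
        torch.cuda.synchronize()   # finish GPU work queued before timing starts
    start = time.perf_counter()
    for d in dataloader:
        pass                       # batches are produced but not consumed
    if sync_cuda:
        torch.cuda.synchronize()   # wait for kernels launched inside the loop
    return time.perf_counter() - start
```

Calling `time_loop(dataloader)` for each strategy then measures the same thing in both cases.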

I’d really love to be able to slice a torch.sparse CSR tensor on GPU as fast as scipy.sparse can slice a CSR matrix on CPU.

This is a bare-bones method I’ve been using for storing & slicing sparse data on the GPU. It produces a dense matrix when sliced, and can only be sliced using a tensor or numpy array of indices (not a single index, nor a list of indices).

This works, but I haven’t tried optimizing the CUDA kernel at all.

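```python
import cutex
import torch
from torch import Tensor
from scipy.sparse import csr_matrix
import numpy as np


def expand_rows(X_indptr: Tensor,
                X_indices: Tensor,
                X_data: Tensor,
                which_rows: Tensor,
                ncols: int):
    # start offset and nonzero count of every requested row
    row_starts = X_indptr[which_rows]
    row_nnz = X_indptr[which_rows + 1] - row_starts
    max_nnz = int(row_nnz.max())  # plain int so the grid size is an integer

    M = len(which_rows)  # output rows (grid y axis)
    N = max_nnz          # widest selected row (grid x axis)

    out = torch.zeros((M, ncols), dtype=X_data.dtype, device=X_data.device)
    # cutex.inline picks up gridDim/blockDim and the names used in the kernel
    # from this scope; "noqa" silences unused-variable lint warnings
    gridDim = (cutex.ceildiv(N, 16), cutex.ceildiv(M, 16), 1)  # noqa
    blockDim = (16, 16, 1)  # noqa
    # one thread per (selected row m, n-th nonzero of that row)
    cutex.inline("""
    int n = blockIdx.x * blockDim.x + threadIdx.x;
    int m = blockIdx.y * blockDim.y + threadIdx.y;

    if (m >= M) return;
    if (n >= row_nnz[m]) return;
    int csr_offset = row_starts[m] + n;
    out[m][X_indices[csr_offset]] = X_data[csr_offset];
    """, boundscheck=False)
    return out


class sparse_cuda_matrix(object):
    """CSR matrix kept on the GPU; slicing rows returns a dense tensor."""

    def __init__(self, X):
        assert isinstance(X, csr_matrix)
        self.indptr = torch.as_tensor(X.indptr, dtype=torch.int32, device="cuda")
        self.indices = torch.as_tensor(X.indices, dtype=torch.int32, device="cuda")
        self.data = torch.as_tensor(X.data, dtype=torch.float32, device="cuda")
        self.ncols = X.shape[1]
        self.shape = X.shape

    def __getitem__(self, indices):
        # indices: tensor or numpy array of row indices
        return expand_rows(self.indptr, self.indices, self.data,
                           indices, self.ncols)


if __name__ == '__main__':
    import scanpy as sc
    import tqdm
    cts = sc.read_h5ad('CRISPR_all_days.h5ad')  # cts.X must be a scipy CSR matrix
    sparse_matrix = sparse_cuda_matrix(cts.X)

    for i in tqdm.tqdm(range(10000)):
        rows = np.random.randint(0, cts.X.shape[0], size=1024)
        s = sparse_matrix[rows].sum()
        # assert s == cts.X[rows, :].sum()
```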
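The commented-out assert only compares sums. A stricter, CPU-only way to check the kernel's indexing logic is a plain NumPy reference that walks the same (m, n) thread grid as expand_rows and compares the result with scipy's own row slicing. This is a slow loop meant for testing only, and the name `expand_rows_reference` is hypothetical.

```python
import numpy as np
import scipy.sparse as sp


def expand_rows_reference(indptr, indices, data, which_rows, ncols):
    """Loop version of the expand_rows kernel: each (m, n) thread becomes
    one iteration of the inner loop body."""
    which_rows = np.asarray(which_rows)
    row_starts = indptr[which_rows]
    row_nnz = indptr[which_rows + 1] - row_starts
    out = np.zeros((len(which_rows), ncols), dtype=data.dtype)
    for m in range(len(which_rows)):      # y axis of the CUDA grid
        for n in range(row_nnz[m]):       # x axis of the CUDA grid
            csr_offset = row_starts[m] + n
            out[m, indices[csr_offset]] = data[csr_offset]
    return out


X = sp.random(200, 50, density=0.05, format="csr", random_state=0)
rows = np.random.default_rng(0).integers(0, X.shape[0], size=32)
ref = expand_rows_reference(X.indptr, X.indices, X.data, rows, X.shape[1])
assert np.array_equal(ref, X[rows].toarray())
```

On a GPU machine, `np.allclose(sparse_matrix[rows].cpu().numpy(), ref)` would be the analogous check for the cutex version (allclose rather than exact equality, since the GPU copy stores the values as float32).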

