DataLoader loads data very slowly on sparse tensors

I think I’ve found some further improvements by using a BatchSampler: it lets the Dataset slice a whole batch of rows out of the sparse matrix in a single call, which removes per-sample operations like np.vstack. I would love to hear feedback if there are any potential issues with this update.

from typing import Union

import numpy as np
import torch
from scipy.sparse import coo_matrix, csr_matrix


class SparseDataset2():
    """
    Custom Dataset class for scipy sparse matrix
    """
    def __init__(self, data: Union[np.ndarray, coo_matrix, csr_matrix],
                 targets: Union[np.ndarray, coo_matrix, csr_matrix],
                 transform=None):
        
        # Transform data coo_matrix to csr_matrix for indexing
        if isinstance(data, coo_matrix):
            self.data = data.tocsr()
        else:
            self.data = data
            
        # Transform targets coo_matrix to csr_matrix for indexing
        if isinstance(targets, coo_matrix):
            self.targets = targets.tocsr()
        else:
            self.targets = targets
        
        self.transform = transform # Can be removed

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

    def __len__(self):
        return self.data.shape[0]
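
As a quick sanity check of the batched indexing (a toy example, with arbitrary sizes): slicing a csr_matrix with a list of row indices returns the whole sub-matrix in a single call, which is what lets a BatchSampler hand a full batch of indices to __getitem__ at once.

from scipy.sparse import random

X_small = random(1000, 50, density=0.25, format='csr')
y_small = np.arange(1000)
ds_small = SparseDataset2(X_small, y_small)

rows, targets = ds_small[[3, 17, 42]]  # one __getitem__ call, three rows
print(rows.shape)   # (3, 50) csr_matrix
print(targets)      # [ 3 17 42]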
      
def sparse_coo_to_tensor2(coo:coo_matrix):
    """
    Transform scipy coo matrix to pytorch sparse tensor
    """
    values = coo.data
    indices = (coo.row, coo.col) # a plain tuple is enough; no np.vstack needed
    shape = coo.shape

    i = torch.LongTensor(indices)
    v = torch.DoubleTensor(values)
    s = torch.Size(shape)

    return torch.sparse.DoubleTensor(i, v, s)
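
A minimal round-trip check of the conversion (toy values; torch.sparse.DoubleTensor is the legacy constructor, torch.sparse_coo_tensor is the modern equivalent):

dense = np.array([[0., 2., 0.],
                  [1., 0., 0.]])
t = sparse_coo_to_tensor2(coo_matrix(dense))
print(t.shape)                                   # torch.Size([2, 3])
assert np.allclose(t.to_dense().numpy(), dense)  # converts back exactly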
    
def sparse_batch_collate2(batch): 
    """
    Collate function which to transform scipy coo matrix to pytorch sparse tensor
    """
    # batch[0] since it is returned as a one element list
    data_batch, targets_batch = batch[0]
    
    if isinstance(data_batch, csr_matrix):
        data_batch = data_batch.tocoo() # removed vstack
        data_batch = sparse_coo_to_tensor2(data_batch)
    else:
        data_batch = torch.DoubleTensor(data_batch)

    if isinstance(targets_batch, csr_matrix):
        targets_batch = targets_batch.tocoo() # removed vstack
        targets_batch = sparse_coo_to_tensor2(targets_batch)
    else:
        targets_batch = torch.DoubleTensor(targets_batch)
    return data_batch, targets_batch
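
To make the batch[0] unpacking concrete, this simulates what the DataLoader hands the collate function under this setup: a one-element list whose single entry is already a full batch.

from scipy.sparse import random

X_toy = random(100, 20, density=0.25, format='csr')
y_toy = np.arange(100)
fake_batch = [(X_toy[[0, 1, 2, 3]], y_toy[[0, 1, 2, 3]])]

data_batch, targets_batch = sparse_batch_collate2(fake_batch)
print(data_batch.shape)     # torch.Size([4, 20]), a sparse DoubleTensor
print(targets_batch.shape)  # torch.Size([4])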

I’m using DoubleTensor above instead of FloatTensor simply because my particular data needs double precision; FloatTensor works the same way if single precision is enough.

Rerunning your code for comparison:

import numpy as np
import torch
from scipy.sparse import random
from torch.utils.data import DataLoader
from tqdm import tqdm

# SparseDataset and sparse_batch_collate are the versions from your post
X = random(800000, 300, density=0.25)
y = np.arange(800000)
ds = SparseDataset(X, y)
dl = DataLoader(ds,
                batch_size=1024,
                shuffle=True,
                collate_fn=sparse_batch_collate,
                generator=torch.Generator(device='cuda'))

for x, y in tqdm(iter(dl)):
  pass
100%|██████████| 782/782 [01:09<00:00, 11.26it/s]

Updated code:

X = random(800000, 300, density=0.25)
y = np.arange(800000)
ds = SparseDataset2(X, y)
sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(
        ds, generator=torch.Generator(device='cuda')),
    batch_size=1024,
    drop_last=False)
# batch_size=1 because the BatchSampler already yields whole batches of indices,
# so each "sample" the DataLoader sees is a full 1024-row batch
dl = DataLoader(ds,
                batch_size=1,
                collate_fn=sparse_batch_collate2,
                generator=torch.Generator(device='cuda'),
                sampler=sampler)

for x, y in tqdm(iter(dl)):
  pass
100%|██████████| 782/782 [00:11<00:00, 71.03it/s]

That’s approximately a 6.3x speed-up (71.03 it/s vs. 11.26 it/s) on the Databricks cluster I’m using.
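
One usage note: each x yielded by this loader is a sparse COO tensor, so a dense model would typically densify one batch at a time instead of the whole matrix (a sketch; model and loss_fn are placeholders, not part of the benchmark above):

for x, y_batch in dl:
    x = x.to_dense()  # densify just this 1024-row batch
    # out = model(x)
    # loss = loss_fn(out, y_batch)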