DataLoader loads data very slowly from a sparse tensor

I currently have about 1 million data points with 3000 sparse features. At first I thought PyTorch sparse tensors would be useful here, but I noticed that loading data from a sparse tensor through a DataLoader was very slow.

Here’s an example of my current situation.

import time
import torch
from torch.utils.data import TensorDataset, DataLoader

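# 800000 x 300 matrix with integer values 0..4 (about 80% nonzero), stored as a sparse COO tensor
x = torch.FloatTensor(800000, 300).random_(0,5).to_sparse()
y = torch.FloatTensor(800000, 1).random_(0,5)

train_ds = TensorDataset(x, y)
train_dl = DataLoader(train_ds, batch_size = 1024)

# Time how long it takes to get the first batch
start_time = time.time()
for x_batch, y_batch in train_dl:
    print("--- %s seconds ---" % (time.time() - start_time))
    break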

# Output:
# --- 497.44726943969727 seconds ---

As you can see, loading from the sparse tensor was very slow (a single batch took about 497 seconds), and it is even worse in my real use case with 3000 sparse features.

So is there a way to speed up loading a large sparse matrix through a DataLoader? Thanks.


I’m back to answer my own question. I came up with a simple solution after researching for a while.

Instead of feeding a PyTorch sparse tensor directly into the DataLoader, I wrote a custom Dataset class that accepts a SciPy sparse matrix (coo_matrix or csr_matrix) or a NumPy array. Then I wrote a custom collate function for the DataLoader that converts each batch from a SciPy sparse matrix into a PyTorch sparse tensor during loading.

Here’s the code

from typing import Union

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from scipy.sparse import (random, 
                          coo_matrix,
                          csr_matrix, 
                          vstack)
from tqdm import tqdm

class SparseDataset(Dataset):
    """
    Custom Dataset class for scipy sparse matrix
    """
    def __init__(self, data:Union[np.ndarray, coo_matrix, csr_matrix],
                 targets:Union[np.ndarray, coo_matrix, csr_matrix],
                 transform=None):

        # Convert coo_matrix to csr_matrix, which supports row indexing
        if isinstance(data, coo_matrix):
            self.data = data.tocsr()
        else:
            self.data = data

        # Same conversion for the targets
        if isinstance(targets, coo_matrix):
            self.targets = targets.tocsr()
        else:
            self.targets = targets

        self.transform = transform # Unused; can be removed

    def __getitem__(self, index:int):
        return self.data[index], self.targets[index]

    def __len__(self):
        return self.data.shape[0]

def sparse_coo_to_tensor(coo:coo_matrix):
    """
    Transform scipy coo matrix to pytorch sparse tensor
    """
    values = coo.data
    indices = np.vstack((coo.row, coo.col))
    shape = coo.shape

    i = torch.from_numpy(indices.astype(np.int64))
    v = torch.from_numpy(values).float()
    s = torch.Size(shape)

    # torch.sparse_coo_tensor replaces the legacy torch.sparse.FloatTensor constructor
    return torch.sparse_coo_tensor(i, v, s)
    
def sparse_batch_collate(batch:list):
    """
    Collate function which transforms a list of scipy csr rows into a pytorch sparse tensor
    """
    data_batch, targets_batch = zip(*batch)
    if isinstance(data_batch[0], csr_matrix):
        # Stack the 1 x n rows into one matrix, then convert it to a sparse tensor
        data_batch = vstack(data_batch).tocoo()
        data_batch = sparse_coo_to_tensor(data_batch)
    else:
        # Dense rows or scalars: build one array first, then convert it in a single step
        data_batch = torch.as_tensor(np.asarray(data_batch), dtype=torch.float32)

    if isinstance(targets_batch[0], csr_matrix):
        targets_batch = vstack(targets_batch).tocoo()
        targets_batch = sparse_coo_to_tensor(targets_batch)
    else:
        targets_batch = torch.as_tensor(np.asarray(targets_batch), dtype=torch.float32)
    return data_batch, targets_batch

Example

X = random(800000, 300, density=0.25)
y = np.arange(800000)
train_ds = SparseDataset(X, y)
train_dl = DataLoader(train_ds, 
                      batch_size = 1024, 
                      collate_fn = sparse_batch_collate)

for x_batch, y_batch in tqdm(train_dl):
    pass

# Output
# 100%|██████████| 782/782 [01:00<00:00, 12.84it/s]

As you can see, loading all 800,000 data points with 300 sparse features (batch size 1024, so 782 batches) takes only about 1 minute (12.84 batches per second), which is quite fast.

Still, I’m not sure whether this is the correct approach, but I’ll stick with it for now. Other solutions are still welcome.
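The SparseDataset above takes SciPy input, while the question starts from a PyTorch sparse tensor. A minimal sketch of the conversion in that direction, assuming a 2-D COO tensor on the CPU (the helper name `torch_sparse_to_csr` is hypothetical, not a PyTorch or SciPy API):

```python
import numpy as np
import torch
from scipy.sparse import csr_matrix


def torch_sparse_to_csr(t: torch.Tensor) -> csr_matrix:
    """Hypothetical helper: convert a 2-D torch sparse COO tensor into a SciPy CSR matrix."""
    t = t.coalesce()  # merge duplicate entries; indices()/values() require a coalesced tensor
    rows, cols = t.indices().numpy()  # indices() has shape (2, nnz)
    values = t.values().numpy()
    return csr_matrix((values, (rows, cols)), shape=tuple(t.shape))


# Small stand-in for the question's x (values 0..4, so roughly 1 in 5 entries is zero)
x = torch.randint(0, 5, (8, 6)).float().to_sparse()
x_csr = torch_sparse_to_csr(x)
print(x_csr.shape, x_csr.nnz)
```

The resulting CSR matrix can then be passed to SparseDataset in place of the `random(...)` matrix used in the example.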


This is exactly what I have been looking for. Thank you for this!

I think I’ve found a further improvement using a BatchSampler. Because the sampler hands a whole batch of indices to __getitem__ at once, the rows no longer have to be stacked with vstack in the collate function. I’d love to hear feedback if there are any potential issues with this update.
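The reasoning behind the BatchSampler change: with a plain DataLoader, `__getitem__` is called once per index and the collate function stacks the rows afterwards, whereas a BatchSampler passed as the `sampler` gives `__getitem__` the full list of indices, so SciPy can slice the whole batch with one fancy-indexing call. A small check on made-up data that both paths return the same rows:

```python
import numpy as np
from scipy.sparse import random as sparse_random, vstack

# Small random CSR matrix standing in for the dataset's features
X = sparse_random(1000, 50, density=0.1, format="csr", random_state=0)
batch_idx = np.array([3, 17, 42, 999])

# Per-row path: one __getitem__ call per index, then vstack in the collate function
rows_then_stack = vstack([X[i] for i in batch_idx]).tocsr()

# Batched path: one CSR fancy-indexing call for the whole list of indices
one_call = X[batch_idx]

print(one_call.shape)  # (4, 50)
```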

class SparseDataset2(Dataset):
    """
    Custom Dataset class for scipy sparse matrix
    """
    def __init__(self, data:Union[np.ndarray, coo_matrix, csr_matrix],
                 targets:Union[np.ndarray, coo_matrix, csr_matrix],
                 transform=None):

        # Convert coo_matrix to csr_matrix, which supports row indexing
        if isinstance(data, coo_matrix):
            self.data = data.tocsr()
        else:
            self.data = data

        # Same conversion for the targets
        if isinstance(targets, coo_matrix):
            self.targets = targets.tocsr()
        else:
            self.targets = targets

        self.transform = transform # Unused; can be removed

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

    def __len__(self):
        return self.data.shape[0]
      
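def sparse_coo_to_tensor2(coo:coo_matrix):
    """
    Transform scipy coo matrix to pytorch sparse tensor (float64)
    """
    values = coo.data
    # Stack row/col into one (2, nnz) array so it converts to a tensor in one step
    indices = np.vstack((coo.row, coo.col))
    shape = coo.shape

    i = torch.from_numpy(indices.astype(np.int64))
    v = torch.from_numpy(values).double()
    s = torch.Size(shape)

    # torch.sparse_coo_tensor replaces the legacy torch.sparse.DoubleTensor constructor
    return torch.sparse_coo_tensor(i, v, s)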
    
def sparse_batch_collate2(batch):
    """
    Collate function which transforms a scipy sparse batch into a pytorch sparse tensor
    """
    # batch[0] since, with batch_size=1, the DataLoader wraps the pre-built batch in a one-element list
    data_batch, targets_batch = batch[0]

    if isinstance(data_batch, csr_matrix):
        data_batch = data_batch.tocoo() # no vstack: the batch is already one matrix
        data_batch = sparse_coo_to_tensor2(data_batch)
    else:
        data_batch = torch.as_tensor(np.asarray(data_batch), dtype=torch.float64)

    if isinstance(targets_batch, csr_matrix):
        targets_batch = targets_batch.tocoo() # no vstack
        targets_batch = sparse_coo_to_tensor2(targets_batch)
    else:
        targets_batch = torch.as_tensor(np.asarray(targets_batch), dtype=torch.float64)
    return data_batch, targets_batch

I’m using float64 (DoubleTensor) above instead of float32 (FloatTensor) only because my particular data needs it.

Rerunning your code for comparison:

from scipy.sparse import random
X = random(800000, 300, density=0.25)
y = np.arange(800000)
ds = SparseDataset(X, y)
# A CUDA generator only works if the default tensor type was switched to CUDA
# (presumably done earlier on this cluster); on a CPU-only setup, drop the generator argument
dl = DataLoader(ds,
                batch_size = 1024,
                shuffle = True,
                collate_fn = sparse_batch_collate,
                generator=torch.Generator(device='cuda'))

for x_batch, y_batch in tqdm(dl):
    pass
# 100%|██████████| 782/782 [01:09<00:00, 11.26it/s]

Updated code:

X = random(800000, 300, density=0.25)
y = np.arange(800000)
ds = SparseDataset2(X, y)
# BatchSampler yields lists of 1024 indices, so __getitem__ receives a whole batch at once
sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(ds,
                                           generator=torch.Generator(device='cuda')),
    batch_size=1024,
    drop_last=False)
# batch_size=1 wraps each pre-built batch in a one-element list (unwrapped in the collate function)
dl = DataLoader(ds,
                batch_size = 1,
                collate_fn = sparse_batch_collate2,
                generator=torch.Generator(device='cuda'),
                sampler = sampler)

for x_batch, y_batch in tqdm(dl):
    pass
# 100%|██████████| 782/782 [00:11<00:00, 71.03it/s]

That is roughly a 6.3x speedup (71.03 vs. 11.26 batches per second) on the Databricks cluster I’m using.

Oops, the second chunk of code above should use SparseDataset2.
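A variant of the BatchSampler trick that might be worth trying (a sketch, not benchmarked here): pass `batch_size=None` instead of `batch_size=1`. That disables the DataLoader’s automatic batching, so each list of indices from the sampler goes straight to `__getitem__` and the collate function receives the `(data, targets)` pair directly, with no `batch[0]` unwrap. The names `CsrBatchDataset` and `csr_batch_to_tensors` are made up for this example, and the CUDA generators are left out so it runs on a CPU-only machine:

```python
import numpy as np
import torch
from scipy.sparse import csr_matrix, random as sparse_random
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler


class CsrBatchDataset(Dataset):
    """Hypothetical dataset: __getitem__ receives a list of row indices and returns a whole batch."""

    def __init__(self, data: csr_matrix, targets: np.ndarray):
        self.data = data.tocsr()
        self.targets = targets

    def __getitem__(self, indices):
        return self.data[indices], self.targets[indices]

    def __len__(self):
        return self.data.shape[0]


def csr_batch_to_tensors(sample):
    """collate_fn for batch_size=None: gets one (csr, targets) pair, no list wrapper."""
    data, targets = sample
    coo = data.tocoo()
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    features = torch.sparse_coo_tensor(indices, values, coo.shape)
    labels = torch.as_tensor(targets)
    return features, labels


X = sparse_random(10_000, 300, density=0.25, format="csr", random_state=0)
y = np.arange(10_000)
ds = CsrBatchDataset(X, y)

sampler = BatchSampler(RandomSampler(ds), batch_size=1024, drop_last=False)
# batch_size=None turns off automatic batching: the sampler's index lists are used as-is
dl = DataLoader(ds, sampler=sampler, batch_size=None, collate_fn=csr_batch_to_tensors)

for x_batch, y_batch in dl:
    pass
```

Since the targets here are just the row numbers, each label batch also tells you which rows of `X` ended up in that feature batch.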

If you’re running on a GPU, try the following version of the conversion, which creates the tensor directly on the device. It should be much faster!

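DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def sparse_coo_to_tensor(coo:coo_matrix):
    """
    Transform scipy coo matrix to pytorch sparse tensor on DEVICE
    """
    # Note: this runs inside the collate function, so with num_workers > 0 the CUDA
    # calls happen in worker processes, which usually fails; keep num_workers=0 here
    values = coo.data
    indices = np.vstack((coo.row, coo.col))
    shape = coo.shape

    i = torch.from_numpy(indices.astype(np.int64)).to(DEVICE)
    v = torch.from_numpy(values).float().to(DEVICE)
    s = torch.Size(shape)

    # i and v already live on DEVICE, so the sparse tensor is created there too
    return torch.sparse_coo_tensor(i, v, s)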
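If you also want `num_workers > 0`, note that the GPU version above makes CUDA calls inside the collate function, which then runs in the worker processes; that typically fails once the main process has already initialized CUDA. An alternative sketch, under that assumption: keep the collate function on the CPU and move each batch to the device in the training loop. The helper name `sparse_coo_to_cpu_tensor` is hypothetical:

```python
import numpy as np
import torch
from scipy.sparse import coo_matrix, random as sparse_random

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def sparse_coo_to_cpu_tensor(coo: coo_matrix) -> torch.Tensor:
    """Hypothetical helper: build the sparse tensor on the CPU (safe inside DataLoader workers)."""
    indices = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data).float()
    return torch.sparse_coo_tensor(indices, values, coo.shape)


# Stand-in for one collated batch
batch = sparse_random(1024, 300, density=0.25, format="coo", random_state=0)
x_cpu = sparse_coo_to_cpu_tensor(batch)

# In the training loop: one host-to-device transfer per batch
x_batch = x_cpu.to(DEVICE)
print(x_batch.device.type, x_batch.shape)
```

The trade-off is one explicit `.to(DEVICE)` per batch in the loop, in exchange for being able to run the conversion in parallel workers.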