I'm back to answer my own question: after researching for a while, I came up with a simple solution.
Instead of feeding a PyTorch sparse tensor directly into the dataloader, I wrote a custom Dataset class which only accepts a scipy coo_matrix or equivalent. Then I wrote a custom collate function for the dataloader which transforms the scipy coo_matrix into a PyTorch sparse tensor during data loading.
Here’s the code:
from typing import Union

import numpy as np
import torch
from scipy.sparse import coo_matrix, csr_matrix, random, vstack
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
class SparseDataset(Dataset):
    """Custom Dataset class for scipy sparse matrices.

    A ``coo_matrix`` cannot be indexed by row, so COO inputs are converted to
    CSR once at construction time. Any other input is stored unchanged.

    Args:
        data: Samples, one per row.
        targets: Targets aligned with ``data`` (one per sample).
        transform: Stored on the instance; this class never applies it.
    """

    def __init__(self, data: Union[np.ndarray, coo_matrix, csr_matrix],
                 targets: Union[np.ndarray, coo_matrix, csr_matrix],
                 transform: bool = None):
        # Transform COO inputs to CSR so that ``obj[index]`` works.
        self.data = data.tocsr() if isinstance(data, coo_matrix) else data
        self.targets = (
            targets.tocsr() if isinstance(targets, coo_matrix) else targets
        )
        self.transform = transform  # Can be removed

    def __getitem__(self, index: int):
        """Return the ``(sample, target)`` pair stored at ``index``."""
        return self.data[index], self.targets[index]

    def __len__(self):
        """Return the number of samples, i.e. the number of rows of ``data``."""
        return self.data.shape[0]
def sparse_coo_to_tensor(coo: coo_matrix):
    """Transform a scipy COO matrix into a PyTorch sparse COO tensor.

    Args:
        coo: Matrix exposing ``row``, ``col``, ``data`` and ``shape``.

    Returns:
        A float32 sparse COO tensor with the same shape and entries as
        ``coo``. Entries are passed through as-is (no coalescing is done).
    """
    # Row and column coordinates stacked into the (2, nnz) index layout.
    indices = torch.as_tensor(np.vstack((coo.row, coo.col)), dtype=torch.long)
    values = torch.as_tensor(coo.data, dtype=torch.float32)
    # torch.sparse_coo_tensor supersedes the legacy torch.sparse.FloatTensor
    # constructor.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))
def _collate_field(items):
    """Stack one field of a batch into a single sparse or dense tensor."""
    if isinstance(items[0], csr_matrix):
        # Stack the per-sample rows, then convert the result via COO.
        return sparse_coo_to_tensor(vstack(items).tocoo())
    # Dense path: build one NumPy array first, then a float32 tensor.
    return torch.as_tensor(np.asarray(items), dtype=torch.float32)


def sparse_batch_collate(batch: list):
    """Collate function that turns scipy sparse rows into PyTorch tensors.

    Args:
        batch: List of ``(data, target)`` pairs, one per sample.

    Returns:
        A ``(data_batch, targets_batch)`` tuple. A field whose first item is a
        ``csr_matrix`` becomes a sparse tensor; any other field becomes a
        dense float32 tensor.
    """
    data_batch, targets_batch = zip(*batch)
    return _collate_field(data_batch), _collate_field(targets_batch)
# Example: 800,000 samples with 300 features at a density of 0.25.
X = random(800000, 300, density=0.25)
y = np.arange(800000)

train_ds = SparseDataset(X, y)
train_dl = DataLoader(train_ds,
                      batch_size=1024,
                      collate_fn=sparse_batch_collate)

# Iterate over the loader once; the loop body is intentionally empty so that
# tqdm reports only the batch-loading speed.
for x_batch, y_batch in tqdm(train_dl):
    pass

# Output
# 100%|██████████| 782/782 [01:00<00:00, 12.84it/s]
As you can see, the whole data loading process for 800000 data points with 300 sparse features (batch size = 1024) takes only about 1 minute (12.84 batches per second), which is quite fast.
Still, I’m not sure whether this is the correct approach, but I’ll stick with it for the moment. Any other solutions are still welcome.