I think I’ve found some further improvements by using a `BatchSampler`. This allows us to remove some operations like `vstack`. I would love to hear some feedback if there are any potential issues with this update.
class SparseDataset2():
    """
    Map-style dataset pairing a feature matrix with a target matrix.

    Each input may be a dense ``np.ndarray`` or a scipy sparse matrix.
    COO inputs are converted to CSR once, at construction time, so that
    ``__getitem__`` can index them; every other input is stored as given.

    Parameters
    ----------
    data : features, one sample per row.
    targets : targets, indexed with the same index as ``data``.
    transform : stored on the instance but never read by this class.
    """

    def __init__(self, data: Union[np.ndarray, coo_matrix, csr_matrix],
                 targets: Union[np.ndarray, coo_matrix, csr_matrix],
                 transform: Union[bool, None] = None):
        self.data = self._as_indexable(data)
        self.targets = self._as_indexable(targets)
        self.transform = transform  # Can be removed

    @staticmethod
    def _as_indexable(matrix):
        """Return ``matrix`` converted to CSR if it is COO, else unchanged."""
        # isinstance rather than ``type(...) ==`` so that subclasses of
        # coo_matrix are converted as well.
        if isinstance(matrix, coo_matrix):
            return matrix.tocsr()
        return matrix

    def __getitem__(self, index):
        """Return ``(data[index], targets[index])`` for the given index."""
        return self.data[index], self.targets[index]

    def __len__(self):
        """Return the number of rows in ``data``."""
        return self.data.shape[0]
def sparse_coo_to_tensor2(coo:coo_matrix):
    """
    Convert a scipy COO matrix into a float64 PyTorch sparse COO tensor.

    Parameters
    ----------
    coo : scipy COO matrix; its ``row``, ``col``, ``data`` and ``shape``
        attributes are read.

    Returns
    -------
    A CPU sparse COO tensor of dtype ``torch.float64`` with the same shape
    and entries as ``coo``.
    """
    # Fill one preallocated (2, nnz) int64 array instead of handing a tuple
    # of ndarrays to the legacy LongTensor constructor.
    nnz = coo.row.shape[0]
    indices = np.empty((2, nnz), dtype=np.int64)
    indices[0] = coo.row
    indices[1] = coo.col
    values = np.ascontiguousarray(coo.data, dtype=np.float64)

    i = torch.from_numpy(indices)
    v = torch.from_numpy(values)
    # Device and dtype are spelled out so the result stays a CPU double
    # tensor, as with the legacy torch.sparse.DoubleTensor constructor.
    return torch.sparse_coo_tensor(
        i, v, torch.Size(coo.shape), dtype=torch.float64, device=v.device
    )
def _batch_part_to_tensor2(part):
    """Convert one half of a batch (CSR matrix or dense data) to a double tensor."""
    # Inspect the batch object itself: slicing out row 0 just to look at its
    # type is wasted work and raises on an empty batch.
    if isinstance(part, csr_matrix):
        return sparse_coo_to_tensor2(part.tocoo())  # removed vstack
    return torch.DoubleTensor(part)


def sparse_batch_collate2(batch):
    """
    Collate function that transforms an already-assembled batch into tensors.

    Parameters
    ----------
    batch : one-element list whose single item is a ``(data, targets)``
        pair covering the whole batch.

    Returns
    -------
    ``(data_batch, targets_batch)``: a CSR input becomes a sparse double
    tensor, anything else is passed to ``torch.DoubleTensor``.
    """
    # batch[0] since it is returned as a one element list
    data_batch, targets_batch = batch[0]
    return _batch_part_to_tensor2(data_batch), _batch_part_to_tensor2(targets_batch)
I’m using `DoubleTensor` above instead of `FloatTensor` just because I need that for my particular data.
Rerunning your code for comparison:
from scipy.sparse import random
# Baseline timing run: the DataLoader shuffles and forms batches of 1024
# samples itself, and sparse_batch_collate assembles each batch.
X = random(800_000, 300, density=0.25)
y = np.arange(X.shape[0])  # one integer target per row of X
ds = SparseDataset(X, y)
dl = DataLoader(
    ds,
    batch_size=1024,
    shuffle=True,
    collate_fn=sparse_batch_collate,
    generator=torch.Generator(device='cuda'),
)
# Walk the loader once; the body is empty because only throughput matters.
for x, y in tqdm(iter(dl)):
    pass
100%|██████████| 782/782 [01:09<00:00, 11.26it/s]
Updated code:
# Timing run with a BatchSampler: the sampler yields a whole list of 1024
# shuffled indices at a time, so the dataset is indexed once per batch.
# NOTE(review): this builds SparseDataset, not SparseDataset2 - confirm intended.
X = random(800_000, 300, density=0.25)
y = np.arange(X.shape[0])  # one integer target per row of X
ds = SparseDataset(X, y)
sampler = torch.utils.data.sampler.BatchSampler(
    torch.utils.data.sampler.RandomSampler(
        ds,
        generator=torch.Generator(device='cuda'),
    ),
    batch_size=1024,
    drop_last=False,
)
# batch_size=1 here: each "sample" is already a full batch, which is why the
# collate function unwraps batch[0].
dl = DataLoader(
    ds,
    batch_size=1,
    collate_fn=sparse_batch_collate2,
    generator=torch.Generator(device='cuda'),
    sampler=sampler,
)
# Walk the loader once; the body is empty because only throughput matters.
for x, y in tqdm(iter(dl)):
    pass
100%|██████████| 782/782 [00:11<00:00, 71.03it/s]
Approximately 6.3x speed up on the Databricks cluster I’m using.