I wrote a quick and dirty DataLoader-like object to operationalize this: instead of the costly DataLoader collation, either (1) use index_select, or (2) shuffle the tensors in-place beforehand and then slice. In my setting, index_select was better, though the improvement from either wasn't dramatic (shuffling is costly!); a rough timing sketch follows the usage example below. Use as follows:
import torch

a = torch.arange(10)
b = torch.arange(10, 20)
c = torch.arange(20, 30)
dataloader = FastTensorDataLoader(a, b, c, batch_size=3, shuffle=True)
print(len(dataloader))
for batch in dataloader:
    print(batch)
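With ten elements and a batch size of 3, len(dataloader) prints 4, and the last batch holds a single element. To back up the comparison above, here is a rough timing sketch of the two strategies on raw tensors (my own example, not part of the loader code: the tensor shape, batch size, and variable names are assumptions, and absolute numbers will vary a lot with data size and hardware):

import time
import torch

x = torch.randn(100_000, 128)
batch_size = 1024
perm = torch.randperm(x.shape[0])

# (1) index_select: gather each batch through the permutation.
t0 = time.perf_counter()
for i in range(0, x.shape[0], batch_size):
    batch = torch.index_select(x, 0, perm[i:i + batch_size])
t1 = time.perf_counter()

# (2) shuffle up front (one big copy), then take cheap contiguous slices.
t2 = time.perf_counter()
x_shuffled = x[perm]
for i in range(0, x.shape[0], batch_size):
    batch = x_shuffled[i:i + batch_size]
t3 = time.perf_counter()

print(f"index_select: {t1 - t0:.4f}s  shuffle-then-slice: {t3 - t2:.4f}s")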
The index_select version:
class FastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader because the built-in DataLoader grabs individual
    indices of the dataset and calls cat to collate them (slow).
    """
    def __init__(self, *tensors, batch_size=32, shuffle=False):
        """
        Initialize a FastTensorDataLoader.

        :param *tensors: tensors to store. Must have the same length @ dim 0.
        :param batch_size: batch size to load.
        :param shuffle: if True, iterate over the data in a new random order
            each time an iterator is created out of this object.
        :returns: A FastTensorDataLoader.
        """
        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
        self.tensors = tensors

        self.dataset_len = self.tensors[0].shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Calculate # of batches (round up if there's a remainder).
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if remainder > 0:
            n_batches += 1
        self.n_batches = n_batches

    def __iter__(self):
        if self.shuffle:
            # Draw a fresh permutation; batches are gathered through it below.
            self.indices = torch.randperm(self.dataset_len)
        else:
            self.indices = None
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.dataset_len:
            raise StopIteration
        if self.indices is not None:
            indices = self.indices[self.i:self.i + self.batch_size]
            batch = tuple(torch.index_select(t, 0, indices) for t in self.tensors)
        else:
            # No shuffling: contiguous slices are essentially free (views).
            batch = tuple(t[self.i:self.i + self.batch_size] for t in self.tensors)
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.n_batches
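To sanity-check the claim in the docstring, here is a hedged end-to-end comparison against the stock TensorDataset + DataLoader pipeline. This is a sketch only: the tensor shapes, batch size, and the time_epoch helper are my own assumptions, and the size of the gap depends on your data and hardware.

import time
import torch
from torch.utils.data import DataLoader, TensorDataset

x = torch.randn(100_000, 128)
y = torch.randint(0, 10, (100_000,))

def time_epoch(loader):
    # Iterate once over the loader and return the elapsed wall-clock time.
    start = time.perf_counter()
    for _ in loader:
        pass
    return time.perf_counter() - start

standard = DataLoader(TensorDataset(x, y), batch_size=1024, shuffle=True)
fast = FastTensorDataLoader(x, y, batch_size=1024, shuffle=True)

print(f"TensorDataset + DataLoader: {time_epoch(standard):.4f}s")
print(f"FastTensorDataLoader:       {time_epoch(fast):.4f}s")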
For the shuffle-in-place version, replace the __iter__ and __next__ methods (this shuffles the tensors stored in the loader, so be careful):
def __iter__(self):
    if self.shuffle:
        # Apply one permutation to every tensor up front; batches below are
        # then plain contiguous slices.
        r = torch.randperm(self.dataset_len)
        self.tensors = [t[r] for t in self.tensors]
    self.i = 0
    return self

def __next__(self):
    if self.i >= self.dataset_len:
        raise StopIteration
    batch = tuple(t[self.i:self.i + self.batch_size] for t in self.tensors)
    self.i += self.batch_size
    return batch
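Because the same permutation r is applied to every tensor, rows stay aligned across tensors after shuffling. A quick sanity check (my own example, not part of the original code) that works with either version:

a = torch.arange(10)
b = a * 10  # b[i] is always 10 * a[i]
loader = FastTensorDataLoader(a, b, batch_size=4, shuffle=True)
for xa, xb in loader:
    # Pairs survive the shuffle: each batch still satisfies b == 10 * a row-wise.
    assert torch.equal(xb, xa * 10)

Note that t[r] indexing returns a new tensor, so the caller's original a and b are left untouched; it is the copies stored on the loader that get reordered each time a new iterator is created.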