Dataloader much slower than manual batching

Hi

I was trying to use a DataLoader to iterate over my training samples, but I don’t understand why it is slower than “manual batching”.

"Manual batching":

samples_tensor = torch.tensor(samples, dtype=torch.float).cuda()
labels_tensor = torch.tensor(labels, dtype=torch.long).cuda()

for e in range(nbEpochs):
    for b in range(nbSamples // batch_size):
        x = samples_tensor[b * batch_size:(b + 1) * batch_size]
        y = labels_tensor[b * batch_size:(b + 1) * batch_size]

"With dataloader":

from torch.utils.data import DataLoader
import torch.utils.data as utils

samples_tensor = torch.tensor(samples, dtype=torch.float).cuda()
labels_tensor = torch.tensor(labels, dtype=torch.long).cuda()

dset = utils.TensorDataset(samples_tensor, labels_tensor)
data_train_loader = DataLoader(dset, batch_size=1000, shuffle=True)

for e in range(nbEpochs):
    for x, y in data_train_loader:
        pass

The DataLoader variant is MUCH slower than the manual process. Am I missing something?

Thanks

Because one is shuffled and the other one is not.

Thanks for your reply.

Even if I manually shuffle the tensors, the manual version is still much faster than the DataLoader.

Oh, I know why. For the DataLoader, since all your dataset exposes is __getitem__, what it does is retrieve the individual samples at the sampled indices one by one and then collate (cat/stack) them into a batch. It is done this way so it can be very general and work on any dataset.
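
For intuition, the per-batch work is roughly equivalent to this simplified sketch (not the actual implementation; make_batch and batch_indices are just hypothetical names standing in for the sampler output and collation step):

import torch

# Rough sketch of what TensorDataset + DataLoader do per batch: one
# __getitem__ call per sampled index, then the per-sample tensors are
# collated (stacked) into a batch. The per-sample Python loop is the
# slow part compared to a single slice.
def make_batch(samples_tensor, labels_tensor, batch_indices):
    items = [(samples_tensor[i], labels_tensor[i]) for i in batch_indices]
    x = torch.stack([s for s, _ in items])
    y = torch.stack([l for _, l in items])
    return x, y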

However, in the case of a TensorDataset you already have all the data in memory and can do this much more efficiently. You could either:

  1. Shuffle the entire tensor beforehand and then take contiguous slices, or
  2. Use index_select (or advanced indexing), which is slower, but still faster than the DataLoader way.

My guess is that if your code mimicked the DataLoader behavior, the two would run at a similar speed.
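
A minimal sketch of both options, reusing samples_tensor, labels_tensor, nbSamples and batch_size from the snippets above:

import torch

# Option 1: draw one random permutation per epoch, index both tensors
# with it (so samples and labels stay aligned), then take contiguous slices.
perm = torch.randperm(samples_tensor.size(0), device=samples_tensor.device)
samples_shuf = samples_tensor[perm]
labels_shuf = labels_tensor[perm]
for b in range(nbSamples // batch_size):
    x = samples_shuf[b * batch_size:(b + 1) * batch_size]
    y = labels_shuf[b * batch_size:(b + 1) * batch_size]

# Option 2: leave the tensors as they are and gather each batch with
# index_select on a slice of the permutation.
perm = torch.randperm(samples_tensor.size(0), device=samples_tensor.device)
for b in range(nbSamples // batch_size):
    idx = perm[b * batch_size:(b + 1) * batch_size]
    x = samples_tensor.index_select(0, idx)
    y = labels_tensor.index_select(0, idx)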


Thanks for the reply.
I have the same problem as the OP had, but I still don’t understand how to get more SPEEED out of a DataLoader-based workflow :frowning_face:
You told the OP to shuffle beforehand and just do normal slicing, and I have some questions:

  1. How do you effectively shuffle a tensor, or a (data, label) pair of tensors?
  2. In what cases should one use DataLoaders rather than a custom batch-generator function?

I wrote a quick and dirty DataLoader-like object to operationalize this: instead of the costly DataLoader collation, either (1) use index_select, or (2) shuffle the tensors in place beforehand and then slice. In my setting, index_select was better, though neither improvement was super dramatic (shuffling is costly!). Use it as follows:

a = torch.arange(10)
b = torch.arange(10, 20)
c = torch.arange(20, 30)

dataloader = FastTensorDataLoader(a, b, c, batch_size=3, shuffle=True)

print(len(dataloader))
for batch in dataloader:
    print(batch)

The index_select version:

class FastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader, because DataLoader fetches individual samples
    from the dataset one index at a time and then collates them (slow).
    """
    def __init__(self, *tensors, batch_size=32, shuffle=False):
        """
        Initialize a FastTensorDataLoader.

        :param *tensors: tensors to store. Must have the same length along dim 0.
        :param batch_size: batch size to load.
        :param shuffle: if True, iterate over the data in a random order (a new
            random permutation of indices is drawn each time an iterator is
            created from this object).

        :returns: A FastTensorDataLoader.
        """
        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
        self.tensors = tensors

        self.dataset_len = self.tensors[0].shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Calculate # batches
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if remainder > 0:
            n_batches += 1
        self.n_batches = n_batches

    def __iter__(self):
        if self.shuffle:
            self.indices = torch.randperm(self.dataset_len)
        else:
            self.indices = None
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.dataset_len:
            raise StopIteration
        if self.indices is not None:
            indices = self.indices[self.i:self.i+self.batch_size]
            batch = tuple(torch.index_select(t, 0, indices) for t in self.tensors)
        else:
            batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors)
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.n_batches

For the shuffle-in-place version, replace the __iter__ and __next__ methods with the ones below (note that this replaces the stored tensors with shuffled copies each time an iterator is created, so be careful):

    def __iter__(self):
        if self.shuffle:
            r = torch.randperm(self.dataset_len)
            self.tensors = [t[r] for t in self.tensors]
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.dataset_len:
            raise StopIteration
        batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors)
        self.i += self.batch_size
        return batch

Thanks for sharing!
The PyTorch DataLoader has a num_workers mechanism for loading data with multiple worker processes.
So, how would one set num_workers in FastTensorDataLoader to get that kind of parallel data loading?
