DataLoader transpose axis for text?

Hello,

I’d like to use the DataLoader and Dataset classes to train an NLP model. These classes are pretty useful, but when the DataLoader returns a batch, it stacks the samples with the batch dimension first. Is there a specific way to get the time axis first, or should I just transpose the axes?

Thanks (:

The batch dimension comes first because:

the DataLoader loads the data in a systematic way such that we stack data vertically instead of horizontally. This is particularly useful for flowing batches of tensors as tensors stack vertically (i.e. in the first dimension) to form batches.
(Building Efficient Custom Datasets in PyTorch | by Syafiq Kamarul Azman | Towards Data Science)
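
To make the quote above concrete, here is a minimal sketch of the default behaviour followed by a plain transpose. The toy data (6 sequences of 5 time steps each) is made up for illustration, and it assumes all sequences have the same length so the default stacking works.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical toy data: 6 sequences, each with 5 time steps.
# A plain Python list works as a map-style dataset.
sequences = [torch.full((5,), i) for i in range(6)]

# The default collate function stacks samples along a new first dimension.
loader = DataLoader(sequences, batch_size=3)

batch = next(iter(loader))          # shape (batch, time) = (3, 5)
time_first = batch.transpose(0, 1)  # shape (time, batch) = (5, 3)
```

So if your sequences are already the same length, transposing the batch after loading is probably the simplest option.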

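The batch-first layout comes from the default_collate function, which stacks the samples along a new first dimension. If you pass default_convert as collate_fn instead, nothing is stacked: each batch comes back as a list of the individual samples, which you can then stack or pad along whichever axis you want: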

from torch.utils.data._utils.collate import default_convert
from torch.utils.data import DataLoader

# dataset: your own Dataset instance
data = DataLoader(dataset, collate_fn=default_convert)
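
If you want the DataLoader itself to hand back time-first batches, another option is a custom collate_fn. The sketch below is one way to do it, not the only one: ToyTextDataset and time_first_collate are hypothetical names, the token ids are made up, and it assumes each sample is a 1-D tensor of token ids and that 0 is an acceptable padding value. It relies on torch.nn.utils.rnn.pad_sequence, which puts the time axis first by default (batch_first=False).

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class ToyTextDataset(Dataset):
    """Hypothetical dataset: each item is a 1-D tensor of token ids."""

    def __init__(self):
        self.samples = [
            torch.tensor([1, 2, 3]),
            torch.tensor([4, 5]),
            torch.tensor([6, 7, 8, 9]),
            torch.tensor([10]),
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def time_first_collate(batch):
    # batch is a list of 1-D tensors of possibly different lengths.
    # pad_sequence pads to the longest one and, with the default
    # batch_first=False, returns a tensor of shape (time, batch).
    return pad_sequence(batch, padding_value=0)  # 0 as pad id is an assumption


loader = DataLoader(ToyTextDataset(), batch_size=2, collate_fn=time_first_collate)
first_batch = next(iter(loader))
print(first_batch.shape)  # torch.Size([3, 2]): (time, batch)
```

This also takes care of variable-length sentences, which the default collate function cannot stack.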