I’d like to use the DataLoader and Dataset classes to train an NLP model. These classes are quite useful, but when the DataLoader returns a batch, it stacks the samples with the batch dimension first. Is there a built-in way to get the time axis first, or should I just transpose the axes?
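This is caused by the default_collate function, which stacks the samples of a batch along a new first dimension. If you pass the default_convert function as collate_fn instead, the samples are not stacked at all: each one keeps its own axes, and the batch comes back as a list of tensors that you can then stack along whichever dimension you need: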
from torch.utils.data._utils.collate import default_convert
from torch.utils.data import DataLoader
# dataset: your own Dataset instance; each batch is returned as a list of samples
data = DataLoader(dataset, collate_fn=default_convert)
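A minimal sketch of getting time-first batches directly, assuming every sample is a (seq_len, features) tensor. ToySequences, time_first_collate and padded_time_first_collate are hypothetical names used only for illustration. The sketch stacks the samples along dim=1 in a custom collate_fn, checks that this matches transposing the output of the default collate, and uses torch.nn.utils.rnn.pad_sequence (whose batch_first=False default is already time-first) for variable-length sequences:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class ToySequences(Dataset):
    """Hypothetical dataset: every sample is a (seq_len, features) tensor."""

    def __init__(self, n_samples=6, seq_len=5, features=3):
        self.data = torch.randn(n_samples, seq_len, features)

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        return self.data[idx]


def time_first_collate(samples):
    # Stack the (seq_len, features) samples along dim=1 so the batch is
    # (seq_len, batch, features) rather than (batch, seq_len, features).
    return torch.stack(samples, dim=1)


def padded_time_first_collate(samples):
    # For variable-length sequences: pad to the longest one; with
    # batch_first=False the result is (max_len, batch, features).
    return pad_sequence(samples, batch_first=False, padding_value=0.0)


ds = ToySequences()

loader = DataLoader(ds, batch_size=2, collate_fn=time_first_collate)
time_first = next(iter(loader))
print(time_first.shape)  # torch.Size([5, 2, 3])

# Alternative from the question: keep the default collate and transpose.
# transpose returns a view, so no data is copied.
default_loader = DataLoader(ds, batch_size=2)
transposed = next(iter(default_loader)).transpose(0, 1)
print(transposed.shape)  # torch.Size([5, 2, 3])

# Variable-length samples, padded to the longest sequence.
padded = padded_time_first_collate([torch.ones(3, 3), torch.ones(5, 3)])
print(padded.shape)  # torch.Size([5, 2, 3])
```

Both routes give the same tensor here; the custom collate_fn just does the reordering once inside the DataLoader instead of in the training loop. Note that a transposed view is non-contiguous, which a few operations (such as view) will reject until you call .contiguous().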