DataLoader batch first

Hey everyone,

I am going to use torch.utils.data.DataLoader for RNNs, but it forms batches batch first: [batch x seq x data]. However, I would like the data in the default PyTorch RNN format: [seq x batch x data].
I could not find an easy way to do something like batch_first=False.

Could you please help me with this?

The DataLoader is just a utility for working with the dataset class (passed as its first argument). Check out the dataset class you pass to the DataLoader: that is where you would have coded the dimensions of each sample it returns, and thus of the batches your DataLoader produces. Note that the default collate function stacks those samples along a new first dimension, which is why the batches come out batch first.
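A minimal sketch of where the batch-first layout comes from, assuming a hypothetical `ToySequences` dataset whose items each have shape (seq, data). The DataLoader's default collate function stacks the items along a new dimension 0, so the batch dimension ends up first:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToySequences(Dataset):
    """Hypothetical dataset: each item is one sequence of shape (seq_len, n_features)."""

    def __init__(self, n_items=8, seq_len=5, n_features=3):
        self.data = torch.randn(n_items, seq_len, n_features)

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        return self.data[idx]  # shape: (seq_len, n_features)


loader = DataLoader(ToySequences(), batch_size=4)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([4, 5, 3]) -> (batch, seq, data)
```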

If for whatever reason this dataset class isn't accessible to you, look into the collate function (see the thread "How to create a dataloader with variable-size input") and call input.permute(1, 0, 2) on your batches there.
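A minimal sketch of the collate-function approach, assuming every sample is a tensor of the same shape (seq, data); `seq_first_collate` and the random `sequences` list are hypothetical names for illustration. It reuses PyTorch's `default_collate` to stack the batch and then permutes it into the [seq x batch x data] layout that `nn.RNN` expects with its default `batch_first=False`:

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate


def seq_first_collate(samples):
    """Hypothetical collate_fn: stack samples batch first, then swap the first two axes."""
    batch = default_collate(samples)  # (batch, seq, data)
    return batch.permute(1, 0, 2)     # (seq, batch, data)


# Hypothetical data: 8 sequences, each 5 time steps long with 3 features.
sequences = [torch.randn(5, 3) for _ in range(8)]
loader = DataLoader(sequences, batch_size=4, collate_fn=seq_first_collate)

batch = next(iter(loader))
print(batch.shape)  # torch.Size([5, 4, 3])

# nn.RNN uses batch_first=False by default, so the batch can be fed in directly.
rnn = torch.nn.RNN(input_size=3, hidden_size=6)
output, h_n = rnn(batch)
print(output.shape)  # torch.Size([5, 4, 6])
```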

Also check out these useful answers on using RNNs with PyTorch.


I ran into a similar issue and made the mistake of using torch.reshape(x, (seq, batch, data)). That only regroups the elements in their existing memory order instead of swapping the axes, so it scrambled the order of my time-series data. I fixed it by using torch.permute(x, (1, 0, 2)) instead.
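A small sketch of why reshape scrambles the sequences while permute does not, using a hypothetical toy tensor of two sequences with three time steps each. reshape keeps the elements in their original memory order and only regroups them into the new shape, whereas permute actually swaps the batch and seq axes:

```python
import torch

# Two sequences (batch=2), each 3 time steps long, 1 feature: (batch, seq, data)
x = torch.tensor([[[1.0], [2.0], [3.0]],
                  [[10.0], [20.0], [30.0]]])

permuted = x.permute(1, 0, 2)  # (seq, batch, data): elements move between axes
reshaped = x.reshape(3, 2, 1)  # same target shape, but memory order is unchanged

print(permuted[:, 0, 0].tolist())  # [1.0, 2.0, 3.0]  -> first sequence intact
print(reshaped[:, 0, 0].tolist())  # [1.0, 3.0, 20.0] -> sequences interleaved
```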
