Unevenly distributed IterableDataset of batch size 1

I’m training an LLM on multiple machines.
Due to memory limits, I set the batch size to 1.
The training data is distributed across ranks with split_dataset_by_node.
The problem is that some processes finish one step earlier than the others, and the remaining processes can’t complete training because they block waiting to compute gradients.
For example, when the dataset has 5 samples and is distributed across 4 GPUs,
the ranks receive [[0], [4]], [[1]], [[2]], [[3]] as batches.
So the first process waits for the others at its second step, but they have already finished and exited the training loop.
Is there any kind of solution?
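To make the imbalance concrete, here is a minimal pure-Python simulation of the round-robin sharding that split_dataset_by_node applies to a single-shard iterable dataset (the datasets call itself is left out so the snippet stays self-contained; the sharding rule is my understanding of its behavior, not a guarantee):

```python
# Simulate how split_dataset_by_node shards a single-shard iterable dataset:
# rank r keeps samples r, r + world_size, r + 2 * world_size, ...
def shard(samples, rank, world_size):
    return [s for i, s in enumerate(samples) if i % world_size == rank]

samples = list(range(5))   # 5 training samples
world_size = 4             # 4 GPUs

shards = [shard(samples, r, world_size) for r in range(world_size)]
print(shards)  # [[0, 4], [1], [2], [3]]

# With batch size 1, rank 0 runs 2 steps while ranks 1-3 run only 1,
# so rank 0 blocks forever in the collective ops of its second step.
```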

Use a DistributedSampler, which solves this issue by repeating samples (or slicing the dataset) on each rank so that all ranks end up with the same number of batches.
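For a map-style dataset, the suggestion above can be sketched like this (PyTorch assumed available): with drop_last=False, the sampler pads the index list by repeating samples from the start, so every rank receives the same number of indices.

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(5))  # 5 samples, as in the example above

world_size = 4
for rank in range(world_size):
    sampler = DistributedSampler(
        dataset,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
        drop_last=False,  # pad by repeating samples instead of dropping the tail
    )
    print(rank, list(sampler))
# rank 0 -> [0, 4], rank 1 -> [1, 0], rank 2 -> [2, 1], rank 3 -> [3, 2]:
# every rank yields 2 indices, so all ranks run the same number of steps.
```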

Thanks for the suggestion.
Does it work with an IterableDataset?

I tried it and got TypeError: object of type 'IterableDataset' has no len(), since IterableDataset doesn’t implement __len__.

No, since a defined length is needed to split the dataset evenly. If you are streaming data, you may need to add custom logic to make sure each rank receives the same number of samples.
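One possible shape for that custom logic (a sketch, not a library API): each rank pads its shard by cycling from its own beginning until every rank has reached a common length. In a true streaming setting the target length has to be agreed on across ranks (e.g. broadcast from rank 0); this sketch assumes it is already known.

```python
from itertools import chain, cycle, islice

def pad_to_length(shard, target_len):
    """Repeat samples from the start of the shard until it has target_len items."""
    shard = list(shard)
    if not shard:
        raise ValueError("an empty shard cannot be padded")
    return list(islice(chain(shard, cycle(shard)), target_len))

# Shards produced by the 5-sample / 4-rank split from the question.
shards = [[0, 4], [1], [2], [3]]
target = max(len(s) for s in shards)  # every rank must take 2 steps

padded = [pad_to_length(s, target) for s in shards]
print(padded)  # [[0, 4], [1, 1], [2, 2], [3, 3]]
```

Repeating a few samples slightly over-weights them in the gradient, but it keeps every rank taking the same number of steps so the collective ops stay in sync.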

Thanks. I think IterableDatasetShard from the transformers library works in this setting.
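If I understand its docstring correctly, IterableDatasetShard(dataset, batch_size=..., num_processes=..., process_index=...) iterates the full stream on every process and, with drop_last=False, loops back to the first samples to complete the last round, so every rank yields the same number of batches. A rough pure-Python sketch of that idea (my own simplification, not the transformers implementation):

```python
def iterable_shard(samples, batch_size, num_processes, process_index, drop_last=False):
    """Yield this process's batches; pad the tail from the stream's start
    so that all processes yield the same number of batches (drop_last=False)."""
    samples = list(samples)
    chunk = batch_size * num_processes           # one "round" of batches
    first = samples[:chunk]                      # reused to pad the last round
    for start in range(0, len(samples), chunk):
        window = samples[start:start + chunk]
        if len(window) < chunk:                  # incomplete last round
            if drop_last:
                return
            window = (window + first)[:chunk]    # loop back to the beginning
        lo = process_index * batch_size
        yield window[lo:lo + batch_size]

# 5 samples, 4 processes, batch size 1: every process yields 2 batches.
out = [list(iterable_shard(range(5), 1, 4, p)) for p in range(4)]
print(out)  # [[[0], [4]], [[1], [0]], [[2], [1]], [[3], [2]]]
```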

That’s interesting; I’m not familiar with that class. Please share an update if you find a working solution.