DataLoader: one input class in each batch?

Is there a way to make the DataLoader produce batches containing only one class each?

For example, the training dataset could contain 1000 images of class A, 1500 images of class B, 1300 images of class C…

I would need the DataLoader to then yield a batch of only A samples, then another batch of B samples, etc.

Thank you so much for any help or tips!

I think the easiest way would be to implement a custom sampler and pass it to the DataLoader.
This sampler could take the batch size and, on each call, yield a list of indices that all belong to the same class.
Your Dataset.__getitem__ method would then receive a whole batch of indices, so you would have to load the complete batch there.
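A minimal sketch of that first approach, assuming a map-style dataset whose labels are known up front. `ToyDataset`, `SingleClassBatchSampler` and the toy `labels` list are hypothetical names for illustration. Passing `batch_size=None` disables the DataLoader's automatic batching, so each list yielded by the sampler is handed directly to `__getitem__`.

```python
import random
from collections import defaultdict

import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class ToyDataset(Dataset):
    """Hypothetical dataset: the feature for index i is just the float value i."""

    def __init__(self, labels):
        self.labels = list(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, indices):
        # Receives a whole list of indices and loads the complete batch at once.
        data = torch.as_tensor(indices).float().unsqueeze(1)
        target = torch.as_tensor([self.labels[i] for i in indices])
        return data, target


class SingleClassBatchSampler(Sampler):
    """Hypothetical sampler yielding lists of indices, each list from one class."""

    def __init__(self, labels, batch_size, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_to_indices[label].append(idx)

    def __iter__(self):
        batches = []
        for indices in self.class_to_indices.values():
            indices = list(indices)
            if self.shuffle:
                random.shuffle(indices)
            # Split each class into chunks; the last chunk of a class may be smaller.
            for start in range(0, len(indices), self.batch_size):
                batches.append(indices[start:start + self.batch_size])
        if self.shuffle:
            random.shuffle(batches)  # mix the order of classes across batches
        return iter(batches)

    def __len__(self):
        return sum(
            (len(v) + self.batch_size - 1) // self.batch_size
            for v in self.class_to_indices.values()
        )


labels = [0] * 10 + [1] * 15 + [2] * 13  # hypothetical labels (A=0, B=1, C=2)
dataset = ToyDataset(labels)
sampler = SingleClassBatchSampler(labels, batch_size=4)
# batch_size=None: each list from the sampler goes straight to dataset.__getitem__.
loader = DataLoader(dataset, sampler=sampler, batch_size=None)

for data, target in loader:
    print(data.shape, target.tolist())  # every target list holds a single class
```

The same sampler could instead be passed via the `batch_sampler` argument; in that case the DataLoader calls `__getitem__` once per index and collates the results, so the dataset would not need to handle lists.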

Alternatively, you could keep a sampler that returns a single index per call and make sure that each run of consecutive indices making up one batch belongs to the same class, so that the default collate_fn creates each batch from a single class.
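A rough sketch of this alternative, assuming the labels are known in advance and the sampler's `batch_size` matches the DataLoader's. `ToyDataset` and `GroupedIndexSampler` are hypothetical. One design choice worth noting: leftover samples of a class that do not fill a whole batch are dropped here, since otherwise one batch could straddle two classes (`drop_last` only drops the final incomplete batch, so it does not solve this).

```python
import random
from collections import defaultdict

import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class ToyDataset(Dataset):
    """Hypothetical map-style dataset returning one (feature, label) pair per index."""

    def __init__(self, labels):
        self.labels = list(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return torch.tensor([float(index)]), self.labels[index]


class GroupedIndexSampler(Sampler):
    """Hypothetical sampler yielding single indices ordered so that every run of
    `batch_size` consecutive indices belongs to one class."""

    def __init__(self, labels, batch_size, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_to_indices[label].append(idx)

    def _chunks(self):
        chunks = []
        for indices in self.class_to_indices.values():
            indices = list(indices)
            if self.shuffle:
                random.shuffle(indices)
            # Keep only full chunks so no batch mixes two classes.
            full = len(indices) - len(indices) % self.batch_size
            chunks += [indices[i:i + self.batch_size] for i in range(0, full, self.batch_size)]
        if self.shuffle:
            random.shuffle(chunks)
        return chunks

    def __iter__(self):
        for chunk in self._chunks():
            yield from chunk

    def __len__(self):
        return sum(len(v) - len(v) % self.batch_size for v in self.class_to_indices.values())


labels = [0] * 10 + [1] * 15 + [2] * 13  # hypothetical labels (A=0, B=1, C=2)
dataset = ToyDataset(labels)
# batch_size must equal the sampler's batch_size so the default collate_fn
# groups exactly one same-class chunk per batch.
loader = DataLoader(dataset, batch_size=4, sampler=GroupedIndexSampler(labels, batch_size=4))

for data, target in loader:
    print(data.shape, target.tolist())
```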


That makes a ton of sense. After researching a bit, I think I will go with the approach where I create a separate Dataset for each input class and wrap them in a ConcatDataset. The DataLoader would then use a custom sampler that alternates between pulling from the different sub-Datasets…
Thanks @ptrblck!
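One possible sketch of that ConcatDataset plan, assuming each class lives in its own dataset. The per-class `TensorDataset`s stand in for the real image datasets, and `AlternatingBatchSampler` is a hypothetical name. It reads `ConcatDataset.cumulative_sizes` to find each sub-dataset's index range, then yields one batch from each sub-dataset in turn (round-robin) through the `batch_sampler` argument, so `__getitem__` still receives single indices.

```python
import random

import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler, TensorDataset

# Hypothetical per-class datasets standing in for the A, B and C images.
sizes = {0: 10, 1: 15, 2: 13}
per_class = [
    TensorDataset(torch.randn(n, 3), torch.full((n,), label, dtype=torch.long))
    for label, n in sizes.items()
]
dataset = ConcatDataset(per_class)


class AlternatingBatchSampler(Sampler):
    """Hypothetical batch sampler: yields lists of global indices into a
    ConcatDataset, cycling through sub-datasets so each batch uses exactly one."""

    def __init__(self, concat_dataset, batch_size, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        # cumulative_sizes[i] is the end offset of sub-dataset i in the ConcatDataset.
        ends = concat_dataset.cumulative_sizes
        starts = [0] + ends[:-1]
        self.ranges = [range(s, e) for s, e in zip(starts, ends)]

    def _class_batches(self, r):
        indices = list(r)
        if self.shuffle:
            random.shuffle(indices)
        return [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]

    def __iter__(self):
        queues = [self._class_batches(r) for r in self.ranges]
        # Round-robin: one batch from each sub-dataset in turn until all are exhausted.
        while any(queues):
            for q in queues:
                if q:
                    yield q.pop(0)

    def __len__(self):
        return sum((len(r) + self.batch_size - 1) // self.batch_size for r in self.ranges)


loader = DataLoader(dataset, batch_sampler=AlternatingBatchSampler(dataset, batch_size=4))
for x, y in loader:
    print(x.shape, y.tolist())
```

Strict alternation means the smaller classes run out first, so the last few batches all come from the larger classes; shuffling the batch order instead would trade that pattern for a random one.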
