I think the easiest way would be to implement a custom sampler and use it in the DataLoader.
This sampler could take the batch size and, for each call, yield a batch's worth of indices that all belong to the same class.
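Your Dataset.__getitem__ method would therefore receive a whole batch of indices at once (e.g. if the sampler is passed with batch_size=None, which disables automatic batching), and you would have to load the complete batch there.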
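A minimal sketch of this first approach, assuming the class labels are available up front as a tensor of targets. `ToyDataset` and `SameClassBatchSampler` are hypothetical names used only for illustration. Passing the sampler via `sampler=` together with `batch_size=None` turns off the DataLoader's automatic batching, so each list of indices the sampler yields is handed directly to `__getitem__`.

```python
import random
from collections import defaultdict

import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class ToyDataset(Dataset):
    """Hypothetical dataset: data[i] is a feature vector, targets[i] its class."""

    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, indices):
        # With batch_size=None the DataLoader passes the whole list of
        # indices yielded by the sampler, so the full batch is loaded here.
        idx = torch.as_tensor(indices)
        return self.data[idx], self.targets[idx]


class SameClassBatchSampler(Sampler):
    """Hypothetical sampler: yields lists of indices that share one class label."""

    def __init__(self, targets, batch_size, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(targets.tolist()):
            self.class_to_indices[label].append(idx)

    def __iter__(self):
        batches = []
        for indices in self.class_to_indices.values():
            indices = list(indices)
            if self.shuffle:
                random.shuffle(indices)
            # Chunk each class separately so no batch mixes classes.
            for start in range(0, len(indices), self.batch_size):
                batches.append(indices[start:start + self.batch_size])
        if self.shuffle:
            random.shuffle(batches)  # interleave classes across the epoch
        return iter(batches)

    def __len__(self):
        return sum(
            (len(v) + self.batch_size - 1) // self.batch_size
            for v in self.class_to_indices.values()
        )


# Toy usage: 10 samples spread over 3 classes.
data = torch.randn(10, 3)
targets = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])
dataset = ToyDataset(data, targets)
loader = DataLoader(
    dataset, sampler=SameClassBatchSampler(targets, batch_size=2), batch_size=None
)
batches = [(x, y) for x, y in loader]
for x, y in batches:
    print(x.shape, y.tolist())  # every y contains a single class label
```

This shuffles within each class and then shuffles the batch order, so classes are interleaved across an epoch; the last batch of each class may be smaller than the batch size.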
Alternatively, you could make sure that each successive call to the sampler returns only the “expected” value (a single index from the current class), so that the collate_fn creates a batch containing a single class.
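The alternative can be sketched by letting the sampler yield single indices in an order where every consecutive run of `batch_size` indices comes from one class, and then using the normal `batch_size` argument so the default collate_fn stacks them. `SameClassIndexSampler` is a hypothetical name, and this sketch assumes the sampler's `batch_size` matches the one passed to the DataLoader.

```python
import random
from collections import defaultdict

import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class SameClassIndexSampler(Sampler):
    """Hypothetical sampler: yields single indices, ordered so that every
    consecutive run of batch_size indices shares one class."""

    def __init__(self, targets, batch_size):
        self.batch_size = batch_size
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(targets.tolist()):
            self.class_to_indices[label].append(idx)

    def _chunks(self):
        chunks = []
        for indices in self.class_to_indices.values():
            indices = random.sample(indices, len(indices))  # shuffled copy
            # Drop each class's incomplete tail; otherwise it would shift the
            # chunk boundaries and mix classes in the following batches.
            usable = len(indices) - len(indices) % self.batch_size
            for start in range(0, usable, self.batch_size):
                chunks.append(indices[start:start + self.batch_size])
        random.shuffle(chunks)
        return chunks

    def __iter__(self):
        for chunk in self._chunks():
            yield from chunk  # one index per step; the DataLoader re-batches them

    def __len__(self):
        return sum(
            len(v) - len(v) % self.batch_size for v in self.class_to_indices.values()
        )


data = torch.randn(10, 3)
targets = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2])
dataset = TensorDataset(data, targets)
sampler = SameClassIndexSampler(targets, batch_size=2)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)  # same batch_size
batches = list(loader)
for x, y in batches:
    print(y.tolist())
```

Because the DataLoader simply chunks the index stream, a class whose size isn't a multiple of the batch size would misalign every later batch; dropping those leftover samples each epoch is one possible trade-off (padding the class, or choosing a different batch size, are others).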
That makes a ton of sense. After researching a bit, I think I will go with the approach of creating a separate Dataset for each input class and wrapping them in a ConcatDataset. The DataLoader would then use a custom sampler that alternates between pulling from the different sub-Datasets…
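A rough sketch of the ConcatDataset idea, assuming one sub-dataset per class. `ConcatDataset.cumulative_sizes` gives the end offset of each sub-dataset in the global index space, and a batch sampler (here a hypothetical `AlternatingBatchSampler`) yields lists of global indices, cycling through the sub-datasets one batch at a time. Passing it as `batch_sampler=` keeps the default collate_fn, so `__getitem__` still receives single indices.

```python
import random

import torch
from torch.utils.data import ConcatDataset, DataLoader, Sampler, TensorDataset


class AlternatingBatchSampler(Sampler):
    """Hypothetical batch sampler: round-robin over the sub-datasets of a
    ConcatDataset, yielding one batch of global indices at a time."""

    def __init__(self, concat_dataset, batch_size):
        self.batch_size = batch_size
        # cumulative_sizes[k] is the end offset of sub-dataset k.
        ends = concat_dataset.cumulative_sizes
        starts = [0] + ends[:-1]
        self.ranges = [range(s, e) for s, e in zip(starts, ends)]

    def __iter__(self):
        # Shuffled per-dataset batches of *global* indices.
        per_dataset = []
        for r in self.ranges:
            idx = random.sample(list(r), len(r))
            per_dataset.append(
                [idx[i:i + self.batch_size] for i in range(0, len(idx), self.batch_size)]
            )
        # Alternate: batch 0 of each dataset, then batch 1 of each, ...
        for step in range(max(len(b) for b in per_dataset)):
            for batches in per_dataset:
                if step < len(batches):
                    yield batches[step]

    def __len__(self):
        return sum((len(r) + self.batch_size - 1) // self.batch_size for r in self.ranges)


# One toy sub-dataset per class.
class0 = TensorDataset(torch.randn(3, 3), torch.zeros(3, dtype=torch.long))
class1 = TensorDataset(torch.randn(4, 3), torch.ones(4, dtype=torch.long))
concat = ConcatDataset([class0, class1])
loader = DataLoader(concat, batch_sampler=AlternatingBatchSampler(concat, batch_size=2))
labels = [y.tolist() for _, y in loader]
print(labels)
```

With these toy sizes the loader alternates between the two classes until the smaller one runs out; when the sub-datasets differ a lot in size, the larger ones end up contributing several consecutive batches at the end of the epoch.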
Thanks @ptrblck!