At most one sample from each class in every batch

Hi, I’m training a classification model with a large number of classes (40,000), and I want each batch to contain no more than one sample from each class.
I’ve tried writing a custom sampler, but it had too many edge cases. I also tried fixing things up after I already had the batch “in my hands”, but that suffered from the same issue.
Do you have any ideas on how to do this simply?
Of course, my batch size << number of classes.
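A minimal sketch of one possible approach is a round-robin batch sampler. It assumes the label of every dataset item is available up front as a list (`labels[i]` is the class of item `i`); `UniqueClassBatchSampler` is a hypothetical name, not a PyTorch class. Each "round" takes exactly one not-yet-used index from every class that still has samples, shuffles them, and cuts the round into batches. Because a batch never spans two rounds, no class can appear twice in the same batch.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Sequence


class UniqueClassBatchSampler:
    """Hypothetical helper: yields lists of dataset indices, at most one per class.

    Assumes labels[i] is the class of dataset item i. It only defines
    __iter__ and __len__, which is what DataLoader uses from a batch_sampler.
    """

    def __init__(self, labels: Sequence[int], batch_size: int,
                 drop_last: bool = False, seed: int = 0) -> None:
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.rng = random.Random(seed)
        # Group dataset indices by their class label.
        self.by_class: Dict[int, List[int]] = defaultdict(list)
        for idx, label in enumerate(labels):
            self.by_class[label].append(idx)

    def __iter__(self) -> Iterator[List[int]]:
        # Fresh per-epoch shuffle of each class's indices (rng state advances each epoch).
        queues = {c: self.rng.sample(idxs, len(idxs))
                  for c, idxs in self.by_class.items()}
        while queues:
            # One round: every class that still has samples contributes exactly one index.
            classes = list(queues)
            self.rng.shuffle(classes)
            round_indices = [queues[c].pop() for c in classes]
            # Forget classes that are now exhausted.
            queues = {c: q for c, q in queues.items() if q}
            # Batches never straddle two rounds, so classes within a batch are unique.
            for start in range(0, len(round_indices), self.batch_size):
                batch = round_indices[start:start + self.batch_size]
                if self.drop_last and len(batch) < self.batch_size:
                    continue
                yield batch

    def __len__(self) -> int:
        counts = [len(idxs) for idxs in self.by_class.values()]
        total = 0
        # Round k contains every class that has more than k samples.
        for k in range(max(counts, default=0)):
            n = sum(1 for c in counts if c > k)
            total += n // self.batch_size if self.drop_last else -(-n // self.batch_size)
        return total
```

It would plug in as `DataLoader(dataset, batch_sampler=UniqueClassBatchSampler(labels, batch_size=256))`. The trade-off is that the last batch of each round can be shorter than `batch_size` (or is skipped with `drop_last=True`), and the late rounds only contain the classes with the most samples.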

Thanks for your help!