How to train each batch to have the same class

I want every batch to contain samples from a single class. For example, with a batch size of 32, the model should update once on 32 samples of class 0, then on 32 samples of another randomly selected class, and so on. If I train on all of the data from class 0 first and only then move on to class 1, the model won't learn properly.

You can write a custom Sampler based on torch.utils.data.WeightedRandomSampler.
I think you could set one class's probability to 1 and the rest to 0, then for the next batch pick another class and set its probability to 1 and the others to 0, and so on.
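A more direct way to get the same effect is a custom batch sampler that groups dataset indices by class and yields each batch entirely from one randomly chosen class, instead of toggling per-class weights in WeightedRandomSampler. This is a minimal sketch, not a drop-in solution; the class name `SingleClassBatchSampler` and the toy `TensorDataset` below are just for illustration:

```python
import random
from collections import defaultdict

import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class SingleClassBatchSampler(Sampler):
    """Yields batches in which every sample belongs to the same class."""

    def __init__(self, labels, batch_size):
        self.batch_size = batch_size
        # Group dataset indices by their class label.
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_to_indices[int(label)].append(idx)

    def __iter__(self):
        # Shuffle each class's indices, then cut them into fixed-size
        # single-class batches (incomplete trailing batches are dropped).
        batches = []
        for indices in self.class_to_indices.values():
            shuffled = indices[:]
            random.shuffle(shuffled)
            for i in range(0, len(shuffled) - self.batch_size + 1, self.batch_size):
                batches.append(shuffled[i:i + self.batch_size])
        # Shuffle the batches so consecutive batches come from random classes.
        random.shuffle(batches)
        return iter(batches)

    def __len__(self):
        return sum(len(v) // self.batch_size for v in self.class_to_indices.values())


# Toy dataset for illustration: 10 classes, 100 samples each.
data = torch.randn(1000, 3, 32, 32)
labels = torch.arange(10).repeat_interleave(100)
dataset = TensorDataset(data, labels)

loader = DataLoader(dataset, batch_sampler=SingleClassBatchSampler(labels, batch_size=32))
for x, y in loader:
    assert y.unique().numel() == 1  # every batch contains exactly one class
```

Passing the sampler via the `batch_sampler` argument (rather than `sampler`) lets it control the composition of each batch directly, so no per-batch weight updates are needed.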