I also ran into the same issue with WeightedRandomSampler, and I found this reply from @ptrblck to be the most helpful one.
So, based on the source code, it would change to something like this:
```python
import torch
from torch.utils.data import Sampler

class DistributedSubset(Sampler):
    """
    https://discuss.pytorch.org/t/how-to-use-my-own-sampler-when-i-already-use-distributedsampler/62143/8
    """
    def __init__(self, indices):
        self.indices = indices
        self.epoch = 0

    def __iter__(self):
        # deterministically shuffle based on the epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        x = [self.indices[i] for i in torch.randperm(len(self.indices), generator=g)]
        return iter(x)

    def __len__(self):
        return len(self.indices)

    def set_epoch(self, epoch):
        self.epoch = epoch
```
Just make sure you call set_epoch() at the start of every epoch so the data is reshuffled each epoch.
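In case it helps, here is a minimal usage sketch of the training loop side; the dataset, the per-rank indices, the batch size, and the epoch count are just placeholders, and it assumes you have already split the indices per process elsewhere:

```python
from torch.utils.data import DataLoader, TensorDataset

# placeholder dataset and placeholder per-rank indices
dataset = TensorDataset(torch.arange(100).float())
indices = list(range(0, 100, 2))          # e.g. this rank's share of the data

sampler = DistributedSubset(indices)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for epoch in range(3):
    sampler.set_epoch(epoch)              # reshuffles deterministically for this epoch
    for batch in loader:
        pass                              # training step goes here
```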
Let me know if this was helpful!