WeightedRandomSampler + DistributedSampler

Hi,

Is there any way to sample with weights in the distributed case?

Thanks.


That’s an interesting use case.
You could probably write a custom sampler deriving from DistributedSampler and pass the weights as an extra argument.
Here you would probably have to add the “extra” weights, and this line of code could probably be replaced by this one.
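A minimal sketch of that idea follows. It assumes a recent PyTorch in which `DistributedSampler` accepts `seed` and exposes `seed`, `epoch` and `num_samples`. The class name `WeightedDistributedSampler` and its `weights` argument are hypothetical and are not part of `torch.utils.data`. The parent `__iter__` supplies this rank's slice of indices, and `torch.multinomial` then draws `num_samples` (the per-rank count) from that slice. Because `multinomial` returns positions into the weight tensor, those positions are mapped back to dataset indices.

```python
import torch
from torch.utils.data import DistributedSampler


class WeightedDistributedSampler(DistributedSampler):
    """Hypothetical sketch: DistributedSampler plus per-sample weights."""

    def __init__(self, dataset, weights, num_replicas=None, rank=None,
                 shuffle=True, seed=0, replacement=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank,
                         shuffle=shuffle, seed=seed)
        # one weight per dataset sample, passed as the "extra" argument
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.replacement = replacement

    def __iter__(self):
        # this rank's disjoint slice (shuffled and padded by DistributedSampler)
        indices = list(super().__iter__())
        # the same seed on every rank is fine: each rank samples only its own slice
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        draws = torch.multinomial(self.weights[indices], self.num_samples,
                                  self.replacement, generator=g)
        # multinomial returns positions into `indices`; map them to dataset indices
        return iter([indices[i] for i in draws.tolist()])
```

As with `DistributedSampler`, `sampler.set_epoch(epoch)` should be called every epoch so that the shuffle and the draws change. One design consequence is that each rank can only oversample items inside its own slice. The weighting is therefore applied per rank rather than over the whole dataset at once.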


@ptrblck Since torch.multinomial takes num_samples as an argument, how can we replace (1) with (2)?


Hi @ptrblck, could you please provide a full implementation of a DistributedSampler with predefined weights? Thanks.

I’ve posted an initial implementation here. However, note that this code is not fully tested for all possible use cases, and I suggest also taking a look at this implementation.

Hi @ptrblck, I ran your code, but it raises an error: “AttributeError: ‘DistributedWeightedSampler’ object has no attribute ‘shuffle’”

Hi @Rabeeh_Karimi, have you solved your problem? If yes, can you provide your implementation?

I’m not sure which code snippet you are using. In case you are using the linked DistributedWeightedSampler, shuffling wouldn’t work, since you are using sample weights to draw the samples.

Many thanks for that valuable post!
I successfully implemented the DistributedWeightedSampler for multi-GPU training, but I noticed that every GPU receives identical data in each batch.
With the standard DistributedSampler, each GPU received different random data per batch.

For better understanding, here is some output from training:
(a) with DistributedSampler
TRAIN on GPU:0: True Label tensor([5, 7, 5, 2, 1, 5, 8, 2], device='cuda:0')
TRAIN on GPU:1: True Label tensor([8, 4, 0, 2, 1, 5, 8, 3], device='cuda:1')
TRAIN on GPU:2: True Label tensor([2, 6, 7, 3, 5, 7, 5, 7], device='cuda:2')
TRAIN on GPU:3: True Label tensor([2, 5, 2, 2, 2, 4, 1, 2], device='cuda:3')

(b) with DistributedWeightedSampler
TRAIN on GPU:0: True Label tensor([2, 4, 1, 3, 6, 4, 4, 4], device='cuda:0')
TRAIN on GPU:1: True Label tensor([2, 4, 1, 3, 6, 4, 4, 4], device='cuda:1')
TRAIN on GPU:2: True Label tensor([2, 4, 1, 3, 6, 4, 4, 4], device='cuda:2')
TRAIN on GPU:3: True Label tensor([2, 4, 1, 3, 6, 4, 4, 4], device='cuda:3')

Code for DistributedWeightedSampler:

```python
import math
from typing import Optional

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Sampler


class DistributedWeightedSampler(Sampler):

    def __init__(self, dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, shuffle: bool = True, seed: int = 0, replacement: bool = True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.replacement = replacement  # if True, a sample can be drawn more than once

    def calculate_weights(self, targets):
        # inverse class frequency; note that this uses the labels of the whole
        # dataset (self.dataset.data.y), not the `targets` argument
        class_sample_count = np.array([len(np.where(self.dataset.data.y == t)[0]) for t in np.unique(self.dataset.data.y)])
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in self.dataset.data.y])
        samples_weight = torch.from_numpy(samples_weight)
        return samples_weight.double()

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        # get targets (you can alternatively pass them in __init__, if this op is expensive)
        # data.data.y == labels
        targets = self.dataset.data.y
        targets = targets[self.rank:self.total_size:self.num_replicas]
        # assert len(targets) == self.num_samples
        weights = self.calculate_weights(targets)
        weighted_indices = torch.multinomial(weights, self.num_samples, self.replacement).tolist()

        return iter(weighted_indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
```

@ptrblck: Do you have any advice on fixing that issue?
I’d like each GPU to receive different, shuffled data per batch, as it does with DistributedSampler.


It seems you are seeding each sampler with the same value, so drawing the same samples would be expected.
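The effect can be reproduced in a single process by simulating the ranks. In the posted `__iter__`, the rank-specific `indices` are computed but never used. `calculate_weights` also ignores its `targets` argument, so every rank draws from the weights of the whole dataset. The sketch below rests on one assumption: each training process seeds the global RNG with the same value, which is a common pattern but is not shown in the post. Both `original_draw` and `per_rank_draw` are hypothetical helpers. The first mimics the posted behavior. The second restricts the draw to the rank's slice and maps the positions returned by `torch.multinomial` back to dataset indices.

```python
import torch


def original_draw(weights, num_samples, seed=0):
    # assumption: every process seeds the global RNG identically, then draws
    # from the weights of the *whole* dataset, as the posted __iter__ does
    torch.manual_seed(seed)
    return torch.multinomial(weights, num_samples, replacement=True).tolist()


def per_rank_draw(weights, rank, num_replicas, num_samples, seed=0, epoch=0):
    # pad to an even split, then keep only this rank's strided slice
    indices = list(range(len(weights)))
    indices += indices[: num_samples * num_replicas - len(indices)]
    indices = indices[rank::num_replicas]
    g = torch.Generator()
    g.manual_seed(seed + epoch)
    pos = torch.multinomial(weights[indices], num_samples, replacement=True, generator=g)
    # positions refer to the slice, so map them back to dataset indices
    return [indices[p] for p in pos.tolist()]


weights = torch.rand(32, dtype=torch.double) + 0.1  # strictly positive weights
world_size, num_samples = 4, 8
before = [original_draw(weights, num_samples) for _ in range(world_size)]
after = [per_rank_draw(weights, r, world_size, num_samples) for r in range(world_size)]
```

Every list in `before` is identical, which matches output (b). Each list in `after` contains only indices from its own rank's slice, so the ranks see different data.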