# How does WeightedRandomSampler work?

Hi,
I have written the code below to understand how `WeightedRandomSampler` works.

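```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import WeightedRandomSampler

inputs = torch.randn(100, 1, 10)
target = torch.floor(3 * torch.rand(100))
trainData = TensorDataset(inputs, target)

# batch_size and trainLoader were undefined in the original post; these
# assumed values reproduce the 10 samples / 4 iterations described below
num_sample = 3            # DataLoader batch size
batch_size = 10           # num_samples: total indices the sampler draws
weight = [0.2, 0.3, 0.7]  # one weight per dataset index (indices 0, 1, 2)
sampler = WeightedRandomSampler(weight, batch_size)
trainLoader = DataLoader(trainData, batch_size=num_sample, sampler=sampler)

for i, (inp, tar) in enumerate(trainLoader):
    print(inp.size())
```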

My fake dataset has 100 instances. When I run the code above, only 10 of them are sampled from the dataset, and the loop runs for only 4 iterations! Could you please explain how it works?

Best


It's a very small function; have a look at how it works: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py#L73-L90
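The core of that function can be sketched roughly as follows. This is a simplified, hypothetical re-implementation (`SimpleWeightedSampler` is not part of PyTorch), assuming the real sampler essentially wraps `torch.multinomial` over the weights; check the linked source for the exact details.

```python
import torch
from torch.utils.data import Sampler


class SimpleWeightedSampler(Sampler):
    """Hypothetical re-implementation of the core idea, for illustration only."""

    def __init__(self, weights, num_samples, replacement=True):
        # One weight per dataset index; the weights need not sum to one
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        # Draw num_samples indices from range(len(weights)), each with
        # probability proportional to its weight
        idx = torch.multinomial(self.weights, self.num_samples, self.replacement)
        return iter(idx.tolist())

    def __len__(self):
        # The DataLoader uses this to decide how many indices to expect
        return self.num_samples


sampler = SimpleWeightedSampler([0.2, 0.3, 0.7], num_samples=10)
print(list(sampler))  # 10 indices, each one of 0, 1 or 2
```

The sampler only hands out integer indices; it never looks at the data or the labels themselves.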


Hey @smth, thanks for your response. I understand `WeightedRandomSampler` now, but I don't understand the logic of `enumerate(trainLoader)`. How many times does the `for` loop execute?
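The loop runs once per batch, and the number of batches is set by the sampler, not by the dataset: the `DataLoader` groups the `num_samples` indices yielded by the sampler into batches of `batch_size`. A sketch assuming `num_samples=10` for the sampler and `batch_size=3` for the loader, one combination that matches the 10 samples and 4 iterations reported above:

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

dataset = TensorDataset(torch.randn(100, 1, 10), torch.floor(3 * torch.rand(100)))

num_samples = 10  # total indices the sampler yields per epoch
batch_size = 3    # indices grouped into each batch

sampler = WeightedRandomSampler([0.2, 0.3, 0.7], num_samples)
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

# The loop runs ceil(num_samples / batch_size) times; len(dataset) plays no role
print(len(loader), math.ceil(num_samples / batch_size))  # 4 4
print([inp.size(0) for inp, _ in loader])                # [3, 3, 3, 1]
```

With `drop_last=True` the incomplete last batch would be dropped, giving 3 iterations instead.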


I've been looking at the code in `DataLoader` and `WeightedRandomSampler`, and I can't see how it takes class labels into account. The code comment says: “weights (sequence) : a sequence of weights, not necessary summing up to one”. That's not very helpful for someone who's trying to learn torch. It looks like `weights` is a list of weights per data point in the dataset we are drawing from, NOT a weight per class (which I initially, maybe carelessly, assumed). If that's the case, you'd have to write code that computes a weight for each data point and somehow “attach” it to that data point, e.g. a text data point becomes (sentence, label, weight for sampling), UNLESS some order is implied on the dataset before we can use `WeightedRandomSampler`.

While the code is fairly straightforward, the semantics around `WeightedRandomSampler` are not clear at all.
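To make the per-data-point semantics concrete, here is a minimal sketch that turns class labels into one weight per sample. It assumes integer labels stored in a tensor in the same order as the dataset and uses inverse class frequency as the weighting scheme; both are illustrative choices, and the imbalanced labels are made up.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical imbalanced data: 80 / 15 / 5 samples of classes 0 / 1 / 2
inputs = torch.randn(100, 1, 10)
target = torch.tensor([0] * 80 + [1] * 15 + [2] * 5)

# Per-class weight: inverse class frequency (one common choice, not required by the sampler)
class_counts = torch.bincount(target, minlength=3).double()
class_weights = 1.0 / class_counts

# Per-sample weight: sample_weights[i] is the weight of dataset index i,
# so it must follow the same order as the dataset
sample_weights = class_weights[target]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
loader = DataLoader(TensorDataset(inputs, target), batch_size=10, sampler=sampler)

# Count how often each class was drawn in one epoch; the counts should be roughly balanced
drawn = torch.cat([tar for _, tar in loader])
print(torch.bincount(drawn, minlength=3))
```

The sampler never sees the labels; the only link between a weight and a data point is its position in `sample_weights`, so no extra "attached" field is needed as long as the order matches.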


Totally agree. I found it today, and it is counter-intuitive.

Your understanding may be correct; please take a look here, perhaps it helps.

Hi, I want to train a network using 3 datasets. I have created my custom dataset class, and now I need to load the 3 datasets with different ratios in one batch. Does anyone have sample code for this?
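One possible approach, sketched under assumptions: concatenate the datasets with `ConcatDataset` and give every sample the weight `ratio / len(its dataset)`, so each dataset's total weight equals its desired share. The three `TensorDataset`s and the ratios below are hypothetical placeholders for your own custom datasets.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset, WeightedRandomSampler

# Three hypothetical datasets of different sizes (stand-ins for custom Dataset classes)
ds_a = TensorDataset(torch.zeros(1000, 1))
ds_b = TensorDataset(torch.ones(200, 1))
ds_c = TensorDataset(torch.full((50, 1), 2.0))
datasets = [ds_a, ds_b, ds_c]

# Desired share of each dataset per batch (assumed values)
ratios = [0.5, 0.3, 0.2]

combined = ConcatDataset(datasets)

# Each sample gets ratio / len(its dataset), so each dataset's weights sum to its ratio.
# The order matches ConcatDataset: all of ds_a, then ds_b, then ds_c.
sample_weights = torch.cat([
    torch.full((len(ds),), r / len(ds), dtype=torch.double)
    for ds, r in zip(datasets, ratios)
])

sampler = WeightedRandomSampler(sample_weights, num_samples=len(combined), replacement=True)
loader = DataLoader(combined, batch_size=20, sampler=sampler)

(batch,) = next(iter(loader))
# Roughly 10 / 6 / 4 samples from datasets a / b / c, varying from batch to batch
print(torch.bincount(batch.view(-1).long(), minlength=3))
```

The ratios only hold in expectation, so individual batches will vary around 10 / 6 / 4. Getting an exact composition in every batch would need a custom batch sampler instead.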


I agree that the docs are not specific enough here. I just spent the last 30 minutes figuring out that the weights are meant to be specified at a data point level, not a class level. In fact, I spent some time trying to figure out how `WeightedRandomSampler` knows what my class labels are.


Hi! Any update on this?

The documentation claims:

> Samples elements from `[0,..,len(weights)-1]` with given probabilities (weights).
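That line is the key point: the sampler only produces indices, and their range is set by the length of `weights`. A quick check, assuming the three weights from the original snippet (the 1000 draws are an arbitrary choice for illustration):

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Three weights: the sampler can only return indices 0, 1 and 2,
# regardless of how large the dataset it is later paired with
sampler = WeightedRandomSampler([0.2, 0.3, 0.7], num_samples=1000, replacement=True)
indices = [int(i) for i in sampler]

print(sorted(set(indices)))  # [0, 1, 2]
# Empirical frequencies approach the normalized weights 0.2/1.2, 0.3/1.2, 0.7/1.2
print(torch.bincount(torch.tensor(indices), minlength=3).double() / len(indices))
```

So in the original snippet only the first three of the 100 data points could ever be returned; to sample across the whole dataset, `weights` needs one entry per data point.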

I agree that this might not be a sufficiently detailed explanation for some users, and I’m sure PRs to improve the docs are welcome.