Hi,
I wrote the code below to understand how WeightedRandomSampler works.
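import torch
from torch.utils.data.sampler import Sampler
from torch.utils.data import TensorDataset as dset

inputs = torch.randn(100, 1, 10)
target = torch.floor(3 * torch.rand(100))
trainData = dset(inputs, target)

num_sample = 3   # passed to DataLoader below, so it acts as the batch size
batch_size = 10  # assumed value: not defined in the post; 10 matches the 10 samples reported below
weight = [0.2, 0.3, 0.7]
# the second argument of WeightedRandomSampler is the total number of samples to draw
sampler = torch.utils.data.sampler.WeightedRandomSampler(weight, batch_size)
trainLoader = torch.utils.data.DataLoader(trainData, num_sample, shuffle=False, sampler=sampler)
for i, (inp, tar) in enumerate(trainLoader):
    print(inp.size())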
I have 100 instances in my fake dataset. When I run the code above, only 10 of them are sampled from the dataset, and the loop runs for 4 iterations in my run. Could you please help me understand how this works?
Hey @smth. Thanks for your response. I understand WeightedRandomSampler now, but I still do not understand the logic of enumerate(trainLoader). How many times does the for loop execute?
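For what it's worth, here is a minimal sketch of how the iteration count comes about. The names num_draws and batch_size are my own, and the uniform weights are just a placeholder. With the default drop_last=False, the loop runs ceil(num_samples / batch_size) times, where num_samples is the second argument of WeightedRandomSampler and batch_size is the second argument of DataLoader.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

inputs = torch.randn(100, 1, 10)
targets = torch.floor(3 * torch.rand(100))
dataset = TensorDataset(inputs, targets)

num_draws = 10   # hypothetical name: total indices the sampler yields per epoch
batch_size = 3   # how many of those indices the DataLoader groups into one batch

# One weight per data point; uniform weights here, purely for illustration.
weights = torch.ones(len(dataset))
sampler = WeightedRandomSampler(weights, num_samples=num_draws, replacement=True)
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

# Collect the size of every batch the loop produces.
batch_sizes = [inp.size(0) for inp, tar in loader]
expected_iterations = math.ceil(num_draws / batch_size)

print(len(batch_sizes), expected_iterations, batch_sizes)
# prints: 4 4 [3, 3, 3, 1]
```

So 10 drawn samples in batches of 3 give three full batches and one batch of size 1, which matches the 4 iterations reported above.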
Been looking at the code in DataLoader and WeightedRandomSampler, I can’t see how it takes class labels into account. From the code comment “weights (sequence) : a sequence of weights, not necessary summing up to one”. Not very helpful really for someone who’s trying to learn torch. It looks like weights is a list of weights per data point in the data set we are drawing from, NOT a weight per class (which I initially, maybe carelessly, assumed). And if that’s the case, you’d have to write code that computes this weight per data point and somehow “attach” that weight to the data point, e.g. a text data point becomes (sentence, label, weight for sampling) UNLESS some order is implied on the data set before we can use WeightedRandomSampler.
While the code is fairly straightforward, the semantics of WeightedRandomSampler are not clear at all.
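In case it helps, here is a rough sketch of turning class labels into per-data-point weights. It assumes inverse class frequency is the weighting you want, which is one common choice and not something the sampler prescribes, and the toy label counts are made up. The weights do not need to be attached to the data points: the i-th weight belongs to the i-th item of the dataset, because the sampler only yields indices and the DataLoader looks those indices up in the dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Toy imbalanced labels (made-up counts): 80 of class 0, 15 of class 1, 5 of class 2.
targets = torch.cat([
    torch.zeros(80, dtype=torch.long),
    torch.ones(15, dtype=torch.long),
    torch.full((5,), 2, dtype=torch.long),
])
inputs = torch.randn(len(targets), 1, 10)
dataset = TensorDataset(inputs, targets)

# One weight per class: inverse frequency (an assumption, not the only option).
class_counts = torch.bincount(targets)
class_weights = 1.0 / class_counts.double()

# One weight per data point, in dataset order: index i gets the weight of its class.
sample_weights = class_weights[targets]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

# Count how often each class was drawn over one pass; expect roughly equal counts.
drawn = torch.cat([tar for _, tar in loader])
print(torch.bincount(drawn, minlength=3))
```

With these weights every class carries the same total weight, so each class is drawn about equally often in expectation, although individual batches will vary.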
Hi, I want to train a network using 3 datasets. I have created my custom dataset class, and I need to load the 3 datasets with different ratios in one batch. Does anyone have sample code for this?
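One possible approach, sketched under some assumptions: concatenate the three datasets with ConcatDataset and give every data point the weight ratio / len(its dataset). The TensorDataset objects below stand in for the custom dataset class, and the sizes and ratios are hypothetical. Note that this only gives the requested mix in expectation; an individual batch will not hit the ratios exactly.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset, WeightedRandomSampler

# Three toy datasets standing in for custom Dataset classes (sizes are made up).
# The "label" here is just the id of the source dataset, so the mix can be inspected.
sizes = [200, 50, 1000]
datasets = [
    TensorDataset(torch.randn(n, 10), torch.full((n,), k, dtype=torch.long))
    for k, n in enumerate(sizes)
]
combined = ConcatDataset(datasets)

# Hypothetical target share of each dataset within a batch.
ratios = [0.5, 0.3, 0.2]

# Each data point gets ratio / dataset size, so each dataset's total weight equals its ratio.
# The order matches ConcatDataset, which indexes the datasets back to back.
weights = torch.cat([
    torch.full((len(ds),), r / len(ds), dtype=torch.double)
    for ds, r in zip(datasets, ratios)
])

sampler = WeightedRandomSampler(weights, num_samples=len(combined), replacement=True)
loader = DataLoader(combined, batch_size=100, sampler=sampler)

# Inspect one batch: counts per source dataset should be near 50 / 30 / 20.
inp, source = next(iter(loader))
print(torch.bincount(source, minlength=3))
```

If the ratios must hold exactly in every batch, a custom batch sampler that draws a fixed number of indices from each dataset would be needed instead; the sketch above does not do that.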
I agree that the docs are not specific enough here. I just spent the last 30 minutes figuring out that the weights are meant to be specified at a data point level, not a class level. In fact, I spent some time trying to figure out how WeightedRandomSampler knows what my class labels are.