Per-class and per-sample weights for WeightedRandomSampler: explanation needed


Let me start by saying I've searched, and searched, and then searched some more: not only the PyTorch forums but also GitHub and other sources.
I believe what confused me is that I blindly copied snippets from this forum.
I have read topics and posts such as:

There seem to be three issues I would really appreciate clarity on:

  1. Does WeightedRandomSampler take the length of the training set as its argument, or the batch size?
    The behaviour changes dramatically if I (wrongly, I presume) pass the batch size.
  2. How exactly should I calculate the weights for WeightedRandomSampler? I've seen snippets from other users and from @ptrblck, and they vary so much that I think they have made me calculate the weights wrongly (see the snippet I'm posting below, which is the actual code I use).
  3. If I use a WeightedRandomSampler, do the CE loss criterion and the way it is back-propagated change? I have seen a snippet here (again from @ptrblck) which suggests that when a WeightedRandomSampler is used with CE loss, the backprop should be different. Furthermore, in the same snippet, the nn.CrossEntropyLoss object seems to be constructed with the actual weights as parameters?

This is how I calculate the actual weights. The method class_samples() returns a list with the number of rows in the dataset for each class/label. The order is preserved, and I have verified this manually.
Most examples I have seen stop at the third line (classes_weight) and don't carry on with the torch.cat operation.
EDIT: I tried removing the bottom three lines, and the network loss drops to zero but Top-1 gets stuck at 50%, so I am assuming this is the correct approach to creating the sample weights.

    import torch

    sample_index   = dataset.class_samples()   # number of samples per class
    num_samples    = len(dataset)
    classes_weight = 1. / torch.tensor(sample_index, dtype=torch.float)
    # synthetic target: 99% label 0, 1% label 1
    target         = torch.cat((torch.zeros(int(num_samples * 0.99), dtype=torch.long),
                                torch.ones(int(num_samples * 0.01), dtype=torch.long)))
    samples_weight = torch.tensor([classes_weight[t] for t in target])

The rest of the setup is more or less boilerplate:

  • I use a ResNet34 for single-label classification of PDF forms (uploaded by users), which are very well structured most of the time.
  • I have a highly imbalanced dataset: some classes have 4000 samples, some only 40, hence the use of WeightedRandomSampler.
  • I get very good, reproducible Top-1 accuracy (99% to 100%). However, when cross-evaluating, the softmax output reports similarity between unseen inputs and known inputs when it really shouldn't, and Top-1 varies widely (by more than 30%) when testing with unseen input. I would read this as too much generalisation/approximation, yet the network reports over-fitting, which makes no sense to me.

Any help would be greatly appreciated and would hopefully put the matter to rest.

  1. Usually you would set num_samples in WeightedRandomSampler to the length of your dataset (it is the number of indices drawn per epoch; the batch size is set on the DataLoader). However, this is not a requirement, and you could also use replacement=True and double the length, if that fits your use case.
  2. Your code looks alright. Basically, each sample should get its own weight. The usual approach is to use the sample's class index to look up the weight assigned to it. Could you please link the posts with varying implementations, so that we can fix them?
  3. The weight argument in nn.CrossEntropyLoss applies a per-class weight to the loss and is orthogonal to weighted sampling. In the linked post, the user wanted per-sample loss weighting during training, which is also a special case and unrelated to random sampling (although all of these approaches can be used to counter the effects of an imbalanced dataset).
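
A minimal sketch of points 1 and 2, using a small synthetic label tensor as a stand-in for the real dataset (the names labels, class_counts, sample_weights and the toy class sizes are assumptions for illustration). Each sample gets the inverse frequency of its own class, num_samples is the dataset length, and the batch size is passed to the DataLoader rather than to the sampler.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical imbalanced labels: 900 samples of class 0, 90 of class 1, 10 of class 2.
labels = torch.cat([
    torch.zeros(900, dtype=torch.long),
    torch.ones(90, dtype=torch.long),
    torch.full((10,), 2, dtype=torch.long),
])
features = torch.randn(len(labels), 4)  # dummy inputs
dataset = TensorDataset(features, labels)

# One weight per class: inverse class frequency.
class_counts = torch.bincount(labels)        # tensor([900, 90, 10])
class_weights = 1.0 / class_counts.float()

# One weight per sample: the weight of the class that sample belongs to.
sample_weights = class_weights[labels]       # shape: (len(dataset),)

# num_samples = how many indices the sampler draws per epoch.
sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset), replacement=True)

# The batch size belongs to the DataLoader, not to the sampler.
loader = DataLoader(dataset, batch_size=50, sampler=sampler)

# Count how often each class is drawn over one epoch.
drawn = torch.cat([y for _, y in loader])
print(torch.bincount(drawn, minlength=3))    # roughly balanced, around 333 per class
```

With replacement=True, num_samples may also exceed len(dataset) (for example 2 * len(dataset)); minority-class samples are then simply repeated more often per epoch. Without replacement, the underlying torch.multinomial call cannot draw more indices than there are weights.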

Hi @ptrblck, thank you for the reply (I owe you more than one at this point!)

  1. When you say replacement=True doubles the length, do you mean it will allow duplication (up-sampling) of the minority classes?
  2. That is what I had understood, although I'm not clear on what the target line does and why you assign 99% of the targets to zeros and 1% to ones. I'll post the code below, including an explanation of the dataloader.
  3. Similarly, that was my understanding, but since I'm only guessing at how it works I'd like to be certain. Is the class weight used to account for the imbalance by weighting each class's share of the back-propagated error? If so, shouldn't the sampler already account for that imbalance and make the weights unnecessary?
    I have seen snippets of yours where you set the CrossEntropyLoss weights and reduction='none', but I am not clear how you then use the criterion and the loss to correctly update the model.
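
To make question 3 concrete, here is a small sketch with dummy logits and hypothetical class weights (all values are made up for illustration). It checks what nn.CrossEntropyLoss does with weight= under reduction='none' and reduction='mean', and shows that a non-reduced loss has to be turned into a scalar before calling backward().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_classes = 3
logits = torch.randn(8, num_classes, requires_grad=True)  # dummy model output
targets = torch.tensor([0, 0, 0, 0, 0, 1, 1, 2])           # class indices
class_weights = torch.tensor([0.2, 1.0, 3.0])               # hypothetical class weights

# reduction='none': one loss per sample, already scaled by the weight of its class.
per_sample = nn.CrossEntropyLoss(weight=class_weights, reduction='none')(logits, targets)
unweighted = F.cross_entropy(logits, targets, reduction='none')
assert torch.allclose(per_sample, class_weights[targets] * unweighted)

# reduction='mean' with a weight divides by the sum of the applied weights,
# not by the batch size.
mean_loss = nn.CrossEntropyLoss(weight=class_weights)(logits, targets)
assert torch.allclose(mean_loss, per_sample.sum() / class_weights[targets].sum())

# With reduction='none' you reduce to a scalar yourself, then back-propagate.
loss = per_sample.mean()
loss.backward()
print(logits.grad.shape)  # torch.Size([8, 3])
```

The same reduction='none' path is what allows arbitrary per-sample weights (multiply per_sample by your own weight vector before reducing). Note also that inverse-frequency sampling and inverse-frequency class weights both push in the same direction: once the sampler already yields roughly balanced batches, adding class weights on top further up-weights the minority classes, so using both usually over-corrects.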

My dataloader uses IP patent forms (uploaded by users), which are highly structured documents with small deviations (especially after resizing). To give you an idea, I've calculated the mean and std over the entire dataset (after resizing to 224x224):

RGB mean: 0.9497, 0.9497, 0.9497
RGB std: 0.0363, 0.0363, 0.0363

This pretty much tells the story (since those values are normalised): most of each image is white (the background is a white page), with sparse pixels that are usually black, or less often coloured (logos, etc.). I don't know whether pre-processing, such as removing the background, or other techniques would help more.
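
The exact method used to compute the statistics above isn't shown, so the following is only a sketch of one common approach (an assumption): pool every pixel of every resized image and compute a per-channel mean and population std, accumulating sums in double precision. The function name channel_mean_std and the synthetic page are hypothetical.

```python
import torch


def channel_mean_std(images):
    """Per-channel mean and std over an iterable of (C, H, W) tensors in [0, 1]."""
    total = None
    total_sq = None
    num_pixels = 0
    for img in images:
        flat = img.flatten(start_dim=1).double()  # (C, H*W)
        s = flat.sum(dim=1)
        sq = (flat ** 2).sum(dim=1)
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        num_pixels += flat.shape[1]
    mean = total / num_pixels
    std = (total_sq / num_pixels - mean ** 2).sqrt()  # population std over all pixels
    return mean.float(), std.float()


# Synthetic "form": a white page with one black bar, identical in all 3 channels.
page = torch.ones(3, 224, 224)
page[:, 100:110, 20:200] = 0.0
mean, std = channel_mean_std([page, page])
print(mean, std)
```

On a mostly white page the mean is close to 1 and the std is small, which matches the pattern described above; identical values across the three channels are what you would expect when the colour channels carry the same content.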

The loader stores the data as a list of tuples (input.cuda().float(), output.cuda().long()), where input is the 224x224 RGB image converted to a tensor, which is only resized and normalised. The output is a one-hot tensor of fixed length C (the number of classes): exactly one value is active, so this is not multi-label.
The actual labels are stored in self.index, an OrderedDict mapping label tuples to output indices, e.g. (country, form) => output_index (simplified), so at any point I can retrieve the actual label from the output and vice versa. All trivial stuff so far.
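
A small sketch of the label bookkeeping described above, with hypothetical (country, form) keys: a dict inverted once gives constant-time lookup from output index to label, and one-hot outputs map back to class indices with argmax, which is what the training loop does with labels.argmax(1) before calling the criterion.

```python
from collections import OrderedDict

import torch
import torch.nn.functional as F

# Hypothetical mapping in the style described above: (country, form) -> output index.
index = OrderedDict([(("US", "form_a"), 0), (("EP", "form_b"), 1), (("JP", "form_c"), 2)])
inverse = {v: k for k, v in index.items()}  # output index -> (country, form)

class_ids = torch.tensor([index[("EP", "form_b")], index[("US", "form_a")]])  # tensor([1, 0])
one_hot = F.one_hot(class_ids, num_classes=len(index))  # one-hot "output" tensors
recovered = one_hot.argmax(dim=1)                        # back to class indices

print(one_hot.tolist())                          # [[0, 1, 0], [1, 0, 0]]
print([inverse[i] for i in recovered.tolist()])  # [('EP', 'form_b'), ('US', 'form_a')]
```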

I count the samples per class as follows:

    def class_samples(self):
        # Count how many rows of the dataset belong to each output index.
        label_samples = [None] * self.num_outputs()
        for i in range(len(self.index)):
            key        = self.get_label(i)       # (country, form) tuple for output index i
            count      = self.labels.count(key)  # number of rows with that label
            label_samples[i] = count
        return label_samples

    def num_outputs(self):
        output_size = len(self.index)
        assert output_size > 0
        return output_size

    def get_index(self, item: tuple) -> int:
        return self.index.get(item)

    def get_label(self, idx) -> tuple:
        # Reverse lookup: linear scan of self.index for the key mapped to idx.
        for key, value in self.index.items():
            if value == idx:
                return key
        return None
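
As an alternative to calling self.labels.count(key) once per class, a sketch assuming the labels have first been converted to integer class indices (for example via get_index): torch.bincount counts every class in a single pass. The function name class_counts and the toy targets are hypothetical.

```python
import torch


def class_counts(targets, num_classes):
    """Number of samples per class, computed in one pass over integer targets."""
    # minlength guarantees an entry (possibly 0) for every class.
    return torch.bincount(targets, minlength=num_classes)


targets = torch.tensor([0, 2, 2, 1, 2, 0])   # hypothetical class index per dataset row
print(class_counts(targets, num_classes=4))  # tensor([2, 1, 3, 0])
```

A class with zero samples would make 1. / count infinite, so it is worth checking that every count is positive before inverting.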

You’ve already seen how I calculate class weights and then sample weights:

    def weights(self):
        sample_index   = self.class_samples()
        num_samples    = len(self)
        classes_weight = 1. / torch.tensor(sample_index, dtype=torch.float)
        target         = torch.cat((torch.zeros(int(num_samples * 0.99), dtype=torch.long),
                                    torch.ones(int(num_samples * 0.01), dtype=torch.long)))
        # I am unsure what this does; it appears to populate the sample weights based
        # on the class weights? However, I would assume that the correct way of doing
        # this would be to assign each sample the weight corresponding to the class
        # it belongs to.
        samples_weight = torch.tensor([classes_weight[t] for t in target])
        return (samples_weight, classes_weight)
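
To see what the target line in weights() produces, here is a sketch with hypothetical counts for three classes: the target is a synthetic 99%/1% split of labels 0 and 1 that never looks at the dataset's real labels, so the resulting sample weights only ever take the weights of classes 0 and 1. Indexing classes_weight with the target tensor is the vectorised equivalent of the list comprehension.

```python
import torch

num_samples = 1000
classes_weight = 1.0 / torch.tensor([600.0, 300.0, 100.0])  # hypothetical class counts

# Synthetic 99% / 1% split of labels 0 and 1, as in the target line of weights().
target = torch.cat((torch.zeros(int(num_samples * 0.99), dtype=torch.long),
                    torch.ones(int(num_samples * 0.01), dtype=torch.long)))
samples_weight = classes_weight[target]  # same values as [classes_weight[t] for t in target]

# Only classes 0 and 1 ever appear, whatever the real labels are.
print(torch.unique(target))          # tensor([0, 1])
print(torch.unique(samples_weight))  # two distinct weights: 1/600 and 1/300
```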

Interestingly, when I unit test it, I see that one sample is not indexed/missing:

    samples weights: torch.Size([4375])
    classes weights: torch.Size([13])
    total samples: 4376
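
The missing sample is most likely a rounding effect rather than an indexing bug, assuming the target is built from int(num_samples * 0.99) zeros and int(num_samples * 0.01) ones as in the torch.cat construction: int() truncates both products, and the two fractional parts are lost. The arithmetic for 4376 samples:

```python
num_samples = 4376
zeros = int(num_samples * 0.99)   # 4332.24 truncated to 4332
ones = int(num_samples * 0.01)    # 43.76 truncated to 43
print(zeros, ones, zeros + ones)  # 4332 43 4375
```

Weights built by looking up each sample's real class always have exactly len(dataset) entries, so this mismatch disappears.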

The weights method is invoked before training:

    (samples_w, classes_w) = dataset.weights()
    classes_w = classes_w.cuda().float()

In the training script I do allow replacement and I do pass the class weights to CE Loss:

    criterion = nn.CrossEntropyLoss(weight=classes_w, reduction='none').cuda()
    sampler =,

Finally, when I do the back-prop:

    for epoch in range(args.epochs):
        # batch/step iteration
        for i, (inputs, labels) in enumerate(train_loader):
            logits        = net(inputs)
            ideal         = labels.argmax(1)
            loss          = criterion(logits, ideal)
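
The loop above stops at computing the loss. Because the criterion uses reduction='none', loss is a vector with one entry per sample, and calling backward() on it directly raises an error. A sketch of a complete step, using a toy linear model, random CPU data and hypothetical class weights in place of the ResNet34, the real loader and classes_w:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real network, data and class weights.
num_classes = 3
net = nn.Linear(10, num_classes)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
classes_w = torch.tensor([0.5, 1.0, 2.0])
criterion = nn.CrossEntropyLoss(weight=classes_w, reduction='none')

inputs = torch.randn(16, 10)
labels = nn.functional.one_hot(torch.randint(0, num_classes, (16,)), num_classes)

for step in range(5):
    optimizer.zero_grad()
    logits = net(inputs)
    ideal = labels.argmax(1)               # one-hot -> class indices
    per_sample = criterion(logits, ideal)  # shape (16,): one loss per sample
    loss = per_sample.mean()               # reduce to a scalar before backward()
    loss.backward()
    optimizer.step()
    print(step, loss.item())
```

per_sample.mean() divides by the batch size; dividing by classes_w[ideal].sum() instead reproduces what reduction='mean' would compute with the same weights.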

Any help is greatly appreciated!
