Does WeightedRandomSampler increase iterations?

Sure, this code shows how to use per-batch weighting.
Note, however, that you will get the same result as passing the weights directly to the criterion, as long as you don’t change the weights based on the current batch:

import torch
import torch.nn as nn

class_weights = torch.tensor([0.1, 0.5, 0.4])
output = torch.randn(10, 3, requires_grad=True)
target = torch.randint(0, 3, (10,))

criterion = nn.CrossEntropyLoss(reduction='none')
loss = criterion(output, target)
# weight each sample by its class weight, then normalize by the sum of the used weights
loss = (loss * class_weights[target]).sum() / class_weights[target].sum()

weighted_criterion = nn.CrossEntropyLoss(weight=class_weights)
weighted_loss = weighted_criterion(output, target)
print(torch.allclose(weighted_loss, loss))
> True
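
Regarding the title question: WeightedRandomSampler does not by itself change the number of iterations per epoch. The DataLoader length is determined by the sampler’s num_samples argument (and the batch size), while the weights only change how likely each sample is to be drawn. A minimal sketch (the dataset and weight values here are made up for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# toy dataset: 100 samples, 3 classes
features = torch.randn(100, 3)
targets = torch.randint(0, 3, (100,))
dataset = TensorDataset(features, targets)

# per-sample weights derived from hypothetical per-class weights
class_weights = torch.tensor([2.0, 1.0, 1.0])
sample_weights = class_weights[targets]

# num_samples controls the epoch length, not the weights
sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset),
                                replacement=True)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

print(len(loader))  # 100 samples / batch size 10 = 10 iterations per epoch
```

If you set num_samples higher or lower than len(dataset), the epoch length changes accordingly, independently of the weight values.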