Is there a limit on how imbalanced a training set can be?

I used PyTorch 0.4.0 and tried to stay as close to your code as possible.
This code works on my machine:

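import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

data = torch.randn(1000, 1)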
target = torch.cat((
    torch.zeros(998),
    torch.ones(1),
    torch.ones(1)*2
)).long()

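# One weight per class; keyword arguments are required by newer scikit-learn releases
cls_weights = torch.from_numpy(
    compute_class_weight(
        'balanced',
        classes=np.unique(target.numpy()),
        y=target.numpy()
    )
)

# One weight per sample, looked up by its class index
weights = cls_weights[target]
sampler = WeightedRandomSampler(weights, len(target), replacement=True)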

dataset = TensorDataset(data, target)

batch_size = 64
loader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=batch_size,
    drop_last=True
)


# Print the fraction of each class in every batch
for x, y in loader:
    for cls in range(3):
        print('Class {}: {}'.format(cls, (y==cls).sum().float() / batch_size))
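
As a rough sanity check on what to expect from the loop above, here is a small sketch that computes the probability of drawing each class directly from the weights. It assumes the 'balanced' heuristic is n_samples / (n_classes * class_count) and rebuilds the weights with torch.bincount (not available in 0.4.0, so this part needs a newer release) instead of calling scikit-learn. The name class_prob is just for illustration.

```python
import torch

# Same label layout as in the snippet above: 998 / 1 / 1 samples per class
target = torch.cat((
    torch.zeros(998),
    torch.ones(1),
    torch.ones(1) * 2
)).long()

# Assumed 'balanced' weighting: n_samples / (n_classes * class_count)
counts = torch.bincount(target).double()
cls_weights = len(target) / (len(counts) * counts)

# Per-sample weights, in the form passed to WeightedRandomSampler
weights = cls_weights[target]

# Probability of drawing a class = summed weight of its samples / total weight
class_prob = torch.zeros(len(counts), dtype=torch.double)
class_prob.index_add_(0, target, weights)
class_prob /= weights.sum()

print(class_prob)
```

Each class ends up with a total weight of n_samples / n_classes, so all three should be drawn with probability of about 1/3, and the per-batch fractions printed by the loop should fluctuate around that value.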

Could you try it out and compare it to your code?
