Is there a limit on how imbalanced a training set can be?

Hey there, I have run into an issue with WeightedRandomSampler: even with the assigned weights, the majority class is still sampled much more heavily than the remaining classes. Is there a limit to the weighted random sampler beyond which resampling no longer reweights the samples?

If you left the argument `replacement` as `True`, I assume the only limit on the weights is the floating-point precision used to represent them.
What are the current values you are using?
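A quick sketch of that claim, using hypothetical toy labels (9,990 samples of class 0 and 10 of class 1, not the poster's data): with `replacement=True`, each draw is independent and proportional to the weights, so even a weight ratio of roughly 1000:1 is honored. With `replacement=False`, each index can be drawn at most once, so the minority class cannot be oversampled at all.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)  # make the draw reproducible

# Hypothetical labels: 9,990 samples of class 0 and only 10 of class 1
labels = torch.cat((torch.zeros(9990), torch.ones(10))).long()

# Inverse-frequency per-sample weights: roughly a 1000:1 ratio between classes
class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].double()

# With replacement=True, indices are drawn independently with probability
# proportional to their weight, so the extreme ratio is still respected
sampler = WeightedRandomSampler(sample_weights, num_samples=10000, replacement=True)
drawn = labels[torch.tensor(list(sampler))]

minority_fraction = (drawn == 1).float().mean().item()
print(f"fraction of class 1 in the draw: {minority_fraction:.3f}")  # close to 0.5
```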

Hi, my training set is very imbalanced, I mean big time. I use sklearn.utils.class_weight.compute_class_weight() to get the class weights and obtain the following coefficients: array([20.98221801, 21.80828959, 20.86804102, 0.25916362]). So there are 3 classes with roughly equal weights and 1 overwhelming majority class.

As far as I understand, there is no need to normalize the weights for the WeightedRandomSampler():

```python
import numpy as np
import torch

# train_targets (integer labels) and class_weights come from the earlier code
# define the array of weights, same shape as the targets
weights = np.empty_like(train_targets, dtype=np.float32)

# assign each sample the weight of its class
for i in range(weights.shape[0]):
    weights[i] = class_weights[train_targets[i]]
weights = torch.from_numpy(weights)
```
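Two points about this step, shown as a small sketch with hypothetical stand-ins for `train_targets` (the labels are made up; the four coefficients are the ones quoted above). First, the per-sample loop can be replaced by NumPy fancy indexing, `class_weights[train_targets]`. Second, normalization is indeed unnecessary: WeightedRandomSampler draws through `torch.multinomial`, which treats the weights as unnormalized, so index `i` is picked with probability `weights[i] / weights.sum()` and scaling all weights by a constant changes nothing.

```python
import numpy as np
import torch

# Hypothetical stand-ins: 4 classes, class 3 dominant
train_targets = np.array([0, 1, 2, 3, 3, 3, 3, 3, 3, 3])
class_weights = np.array([20.98221801, 21.80828959, 20.86804102, 0.25916362])

# Loop version, as in the post
loop_weights = np.empty_like(train_targets, dtype=np.float32)
for i in range(loop_weights.shape[0]):
    loop_weights[i] = class_weights[train_targets[i]]

# Vectorized equivalent: index the class-weight array with the label array
vec_weights = class_weights[train_targets].astype(np.float32)
assert np.array_equal(loop_weights, vec_weights)

# Only ratios matter: probabilities are weights / weights.sum(),
# so multiplying every weight by a constant leaves them unchanged
w = torch.from_numpy(vec_weights).double()
probs = w / w.sum()
probs_scaled = (100 * w) / (100 * w).sum()
print(torch.allclose(probs, probs_scaled))  # True
```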

Then I try to get only indices from the DataLoader, so I have:

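```python
from torch.utils.data import DataLoader, WeightedRandomSampler

# Data is the index-returning Dataset subclass shown further down in the thread
data = Data()

weighted_sampler = WeightedRandomSampler(weights=weights, num_samples=len(train_targets), replacement=True)

# define dataloader
dataloader = DataLoader(data, sampler=weighted_sampler, batch_size=64, drop_last=True)
```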

However, the DataLoader output is still heavily imbalanced:

```python
for index in dataloader:
    print(index.numpy())
    print(targets[index])
    break
```
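If the goal is only to obtain batches of indices for subsetting the training set, one alternative (an assumption on my part, not the approach used in the thread) is to skip the Dataset/DataLoader and iterate a `BatchSampler` wrapped around the `WeightedRandomSampler`. The weights below are hypothetical. Either way, the indices are positions into the array the weights were built from, so they must be used to index that same array.

```python
import torch
from torch.utils.data import BatchSampler, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical per-sample weights for 6 samples (the last two are heavily favored)
weights = torch.tensor([0.1, 0.1, 0.1, 0.1, 2.0, 2.0], dtype=torch.double)

sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

# BatchSampler groups the sampler's indices into lists of batch_size
batch_sampler = BatchSampler(sampler, batch_size=3, drop_last=True)

batches = list(batch_sampler)
for batch in batches:
    print(batch)  # each batch is a plain Python list of 3 indices
```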


Thanks for the code!
I’m not familiar with sklearn’s compute_class_weight method, but the docs suggest it returns inverse class frequencies, which should be fine.
It seems classes 0 to 2 are your minority classes, while you have a lot of samples of class 3. Is that correct?
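For reference, sklearn documents the `'balanced'` heuristic as `n_samples / (n_classes * np.bincount(y))`. The sketch below recomputes it with NumPy on hypothetical labels (assumed to be the integers 0..K-1, all present) and shows why it suits a weighted sampler: every class ends up with the same total weight, `n_samples / n_classes`.

```python
import numpy as np

# Hypothetical labels: three small classes and one dominant class 3
y = np.array([0] * 5 + [1] * 5 + [2] * 5 + [3] * 400)

# 'balanced' heuristic: weight_c = n_samples / (n_classes * count_c)
counts = np.bincount(y)
class_weights = len(y) / (len(counts) * counts)

# Each class carries the same total weight: count_c * weight_c = n_samples / n_classes,
# so a weighted sampler treats the classes as equally likely
per_class_mass = counts * class_weights
print(per_class_mass)  # every entry is (up to rounding) len(y) / 4 = 103.75
```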

Skimming through your code, I cannot find any obvious errors. Could you tell me your PyTorch version, so that I can look into why it’s not working later?

Yes, that is correct. Here is my PyTorch version:
`pytorch-cpu 0.4.0 py36_cpuhe774522_1 pytorch`


I used 0.4.0 and tried to stay as close to your code as possible.
This code works on my machine:

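```python
import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

data = torch.randn(1000, 1)
# 998 samples of class 0, one sample each of classes 1 and 2
target = torch.cat((
    torch.zeros(998),
    torch.ones(1),
    torch.ones(1) * 2
)).long()

# inverse-frequency ("balanced") class weights; keyword arguments are
# required by recent scikit-learn versions
cls_weights = torch.from_numpy(
    compute_class_weight(class_weight='balanced',
                         classes=np.unique(target.numpy()),
                         y=target.numpy())
)

# per-sample weights: look up each sample's class weight
weights = cls_weights[target]
sampler = WeightedRandomSampler(weights, len(target), replacement=True)

dataset = TensorDataset(data, target)

batch_size = 64
loader = DataLoader(
    dataset,
    sampler=sampler,
    batch_size=batch_size,
    drop_last=True
)

# print the fraction of each class in every batch
for x, y in loader:
    for cls in range(3):
        print('Class {}: {}'.format(cls, (y == cls).sum().float() / batch_size))
```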

Could you try it out and compare it to your code?
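To check the balance over a whole epoch rather than batch by batch, the per-batch printout can be replaced by accumulating counts with `torch.bincount`. This sketch reuses the same toy data but, as an assumption for self-containment, computes the `'balanced'` weights directly in torch instead of calling sklearn.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Same toy setup: 998 samples of class 0, one each of classes 1 and 2
data = torch.randn(1000, 1)
target = torch.cat((torch.zeros(998), torch.ones(1), torch.ones(1) * 2)).long()

# 'balanced' weights computed in torch: n_samples / (n_classes * count_c)
counts = torch.bincount(target)
cls_weights = len(target) / (len(counts) * counts.double())
sampler = WeightedRandomSampler(cls_weights[target], len(target), replacement=True)

loader = DataLoader(TensorDataset(data, target), sampler=sampler,
                    batch_size=64, drop_last=True)

# Accumulate class counts over the whole epoch instead of per batch
epoch_counts = torch.zeros(len(counts), dtype=torch.long)
for _, y in loader:
    epoch_counts += torch.bincount(y, minlength=len(counts))

fractions = epoch_counts.double() / epoch_counts.sum()
print(fractions)  # each entry should be roughly 1/3
```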


OK, I found the difference: I was using this Dataset subclass:

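```python
from torch.utils.data import Dataset

# dataset class for the data
class Data(Dataset):
    def __init__(self):
        pass

    def __len__(self):
        pass  # stub: returns None, so calling len(Data()) would raise a TypeError

    def __getitem__(self, idx):
        return idx
```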

I expected it to return the indices I could use to subset the training set, but the weighting didn’t work. It DOES work with TensorDataset, so I will use that instead. Thank you for your help!

Good to hear it’s working now!
However, your approach should also work, since the sampler draws the indices according to the weights, and those indices are what select the target classes.
You could replace the TensorDataset with this dummy dataset:

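```python
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __getitem__(self, index):
        # return the sampled index itself instead of a data sample
        return index

    def __len__(self):
        return len(self.data)
```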

You will see that indices 998 and 999 (the two minority samples) are repeated very often.
Could you provide a code snippet so that I could have another look and make sure it’s not a bug for a special edge case?
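A self-contained sketch of that check: the index-returning `MyDataset` plugged into a DataLoader with the weighted sampler, counting how often each index comes back. The toy data mirrors the earlier example; the per-sample weights are plain inverse class counts computed in torch (an assumption for self-containment, since only the ratios matter).

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler

torch.manual_seed(0)

class MyDataset(Dataset):
    """Dummy dataset that returns the sampled index itself."""
    def __init__(self, data, target):
        self.data = data
        self.target = target

    def __getitem__(self, index):
        return index

    def __len__(self):
        return len(self.data)

data = torch.randn(1000, 1)
target = torch.cat((torch.zeros(998), torch.ones(1), torch.ones(1) * 2)).long()
counts = torch.bincount(target)
weights = (1.0 / counts.double())[target]  # inverse-frequency per-sample weights

sampler = WeightedRandomSampler(weights, len(target), replacement=True)
loader = DataLoader(MyDataset(data, target), sampler=sampler,
                    batch_size=64, drop_last=True)

index_counts = Counter()
for index in loader:  # default collate turns each batch of ints into a LongTensor
    index_counts.update(index.tolist())

# Indices 998 and 999 (classes 1 and 2) should each appear roughly 320 times out of 960
print(index_counts[998], index_counts[999])
```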

So here is the thing: I defined the weights for train_targets, but I used the sampled indices to look up values in targets, which is a superset of train_targets. Because every index sampled for train_targets is also a valid index into targets, no exception was raised and the mistake went unnoticed.

If I define the weights for the full targets set and sample the indices from it as well, everything works as it should. Thank you for your time and help!
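A minimal reproduction of this pitfall, under assumed toy data (the names `targets` and `train_targets` follow the thread; the label values and the random split are hypothetical). The weights are built for `train_targets`, so the sampled indices are positions in `train_targets`. Looking them up in the superset `targets` raises no error, but the labels no longer correspond to the weights and the apparent class balance collapses back toward the population ratio.

```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)
rng = np.random.default_rng(0)

# Hypothetical full label array (the superset): 90% class 0, 10% class 1
targets = np.array([0] * 9000 + [1] * 1000)
rng.shuffle(targets)

# Hypothetical training split: a random subset of positions, so
# train_targets is NOT simply a prefix of targets
train_idx = rng.permutation(len(targets))[:8000]
train_targets = targets[train_idx]

# Weights are built for train_targets, so sampled indices refer to train_targets
counts = np.bincount(train_targets)
weights = torch.from_numpy(1.0 / counts[train_targets])
sampler = WeightedRandomSampler(weights, len(train_targets), replacement=True)
idx = np.array(list(sampler))

right = (train_targets[idx] == 1).mean()  # same array the weights came from
wrong = (targets[idx] == 1).mean()        # superset: indices are valid, labels are not
print(f"class 1 fraction, train_targets[idx]: {right:.2f}")  # about 0.5
print(f"class 1 fraction, targets[idx]:       {wrong:.2f}")  # about 0.1
```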
