Hi, I have a highly imbalanced image dataset and would like to use the WeightedRandomSampler to sample from my dataset such that the model sees each class approximately the same number of times. However, since my classes are sicknesses present in some images, I have three classes for three sicknesses plus an implicit class for no sickness. My labels look like this: [[1,1,0], [0,0,0], [1, 0, 0], …].
How can I use the weighted sampler in this case? If I understood correctly, I would need to give 3 weights for the three classes, but that would not take into account the images with label [0,0,0], which are the majority.
Below is a minimal working example using the weighted sampler, but as written none of the three classes ever gets selected:
import torch
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
from sklearn.utils import shuffle
from torch.utils.data import DataLoader
len_ds = 60461
counts = {'A': 6775, 'B': 3609, 'C': 906}
def create_data(which_class: str):
arr = np.zeros((1, len_ds))
arr[:, :counts[which_class]] = 1
return arr
data = {k: create_data(k) for k in counts}
data = np.concatenate((data['A'], data['B'], data['C']), axis=0)
label_df = pd.DataFrame(data).transpose()
label_df.columns = [*counts.keys()]
label_df = shuffle(label_df)
label_df.reset_index(inplace=True, drop=True)
print(label_df[label_df == 1].count())
class toy_dataset(Dataset):
def __getitem__(self, index):
return torch.tensor(label_df.iloc[index].values)
def __len__(self):
return len(label_df)
dataset = toy_dataset()
def calculateWeights(label_dict, d_set):
arr = []
for label, count in label_dict.items():
weight = len(d_set) / count
arr.append(weight)
return arr
weights = calculateWeights(counts, dataset)
weights = torch.DoubleTensor(weights)
print('weights: ', weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(dataset), replacement=True)
trainloader = DataLoader(dataset, batch_size=50, sampler=sampler,
shuffle=False,
num_workers=0, pin_memory=False)
all_labels = torch.Tensor()
for labels in trainloader:
all_labels = torch.cat((all_labels, labels.cpu()), 0)
print(all_labels.shape)
print(
{label: all_labels[:, i].sum().item() for i, label in
enumerate([*counts.keys()])})
which returns:
torch.Size([60461, 3])
{'A': 6775.0, 'B': 3609.0, 'C': 906.0}
weights: tensor([ 8.9241, 16.7528, 66.7340], dtype=torch.float64)
torch.Size([60461, 3])
{'A': 0.0, 'B': 0.0, 'C': 0.0}