Hi there!
I am facing a problem with WeightedRandomSampler: I want to sample data uniformly across several labels.
A short description of the problem:
The problem concerns plankton image classification. The data is heavily imbalanced, as the majority of images are just “detritus” (which is not even plankton).
In numbers, the most common label, “detritus”, has ~100k images in the dataset, whereas the rarest has just 12 samples. I want a high macro-averaged F1 score, so I want to train my models on a uniform distribution of labels.
Here I build the data loaders:
```python
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# weights for weighted random sampler
X_ = plankton_df['objid']
y_ = plankton_df['level2']
X_train, X_test, y_train, y_test = train_test_split(X_, y_, test_size=0.15, stratify=y_)
weights_df = plankton_df[plankton_df['objid'].isin(X_train.values)]  # assign weights only to samples in the training set
weights_df['level2'] = weights_df['level2'].map(label_mapping)
weights_df['count'] = weights_df.groupby('level2')['level2'].transform('count')
weights_df['count'] = 1. / (weights_df['count'])
weights = weights_df['count'].values
weights = torch.DoubleTensor(weights)
detritus_wrs = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights), replacement=True)
# Extract images in memory
plank_train_dataset = PlanktonDataset(X_train.values, y_train.values,
                                      img_files=img_files,
                                      label_mapping=label_mapping,
                                      transform=training_transformations,
                                      )
plank_train_dataloader = DataLoader(plank_train_dataset, batch_size=64, sampler=detritus_wrs, num_workers=4)
```
With the line `weights_df['count'] = 1. / (weights_df['count'])` I just want to make sure that each sample is assigned the inverse of the cardinality (count) of its label.
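A minimal sketch of the intended weighting, on a hypothetical toy DataFrame standing in for `plankton_df`. It rests on one property of the sampler: WeightedRandomSampler draws indices into `weights`, and the DataLoader uses each index to fetch that position from the dataset, so `weights[i]` has to describe the i-th sample passed to `PlanktonDataset`. Here the weights are computed from `y_train` itself, so they follow the same order as `X_train.values`; the fixed row permutation is only an assumed stand-in for the shuffle that `train_test_split` applies by default.

```python
import pandas as pd

# Hypothetical toy stand-in for plankton_df (not the real data).
plankton_df = pd.DataFrame({
    "objid": [10, 11, 12, 13, 14, 15],
    "level2": ["detritus", "detritus", "detritus", "detritus", "copepod", "diatom"],
})

# Fixed permutation standing in for the shuffled order train_test_split returns.
train = plankton_df.iloc[[3, 5, 0, 4, 1, 2]]
X_train, y_train = train["objid"], train["level2"]

# Inverse class frequency over the training labels, looked up per sample,
# so weights[i] belongs to the i-th element of X_train / y_train.
class_counts = y_train.value_counts()
weights = (1.0 / y_train.map(class_counts)).to_numpy()
print(weights)

# Boolean filtering keeps plankton_df's row order, not X_train's order.
filtered_ids = plankton_df.loc[plankton_df["objid"].isin(X_train.values), "objid"].tolist()
print(filtered_ids)      # [10, 11, 12, 13, 14, 15]
print(X_train.tolist())  # [13, 15, 10, 14, 11, 12]
```

The resulting array can be handed to WeightedRandomSampler as-is, since it accepts a sequence of floats. The last two prints show that a frame filtered with `isin` lists the same ids in a different order from `X_train`, which is worth comparing against the order the dataset receives.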
But if I run this cell:
```python
minibatch = next(iter(plank_train_dataloader))
print(minibatch['label'])
```
I get this output (the labels are not uniform at all):
```
tensor([ 1, 0, 2, 4, 0, 0, 0, 2, 0, 15, 1, 3, 0, 0, 20, 11, 2, 0,
0, 0, 0, 0, 1, 0, 0, 0, 2, 3, 1, 1, 2, 0, 12, 0, 0, 0,
0, 15, 26, 0, 21, 0, 0, 27, 9, 2, 0, 3, 0, 0, 0, 2, 0, 0,
0, 16, 0, 3, 0, 0, 1, 4, 0, 0])
```
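A single batch of 64 is a small sample when there are this many classes, so counting labels over many draws gives a clearer picture of what the sampler does. The sketch below uses hypothetical toy labels (not the plankton data) with inverse-frequency weights in the same order as the labels; each class then carries the same total weight, so each should come out at roughly a third of the draws.

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical toy labels: class 0 dominates, class 2 is rare.
labels = [0] * 1000 + [1] * 50 + [2] * 5

# Inverse-frequency weight per sample, in the same order as `labels`.
class_counts = Counter(labels)
weights = [1.0 / class_counts[y] for y in labels]

# Seeded generator so the draw is reproducible.
generator = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(weights, num_samples=30_000, replacement=True, generator=generator)

# The sampler yields indices into `labels`; map them back to labels and count.
drawn = Counter(labels[i] for i in sampler)
total = sum(drawn.values())
for label in sorted(drawn):
    print(label, round(drawn[label] / total, 3))
```

With a real DataLoader the same check can be done by accumulating `minibatch['label']` over a few hundred batches instead of looking at one.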
Do you have any idea what is going wrong? It is more likely that there is a mistake in my code than a bug in the library.
I tried to be as concise as possible, but if you need other pieces of code or more information about the problem, just ask here.
Thanks in advance for your consideration