Dear all,
I have an imbalanced dataset for training which I have used WeightedRandomSampler for train, dev_train and test separately:
class_weights=[1,112/75,112/32] # class 0 (112), class 1 (75) class 2 (32)
train_sample_weights=[0]*len(datasets_list['train']) #172
dev_train_sample_weights=[0]*len(datasets_list['dev_train']) #21
test_sample_weights=[0]*len(datasets_list['test']) #26
from torch.utils.data import WeightedRandomSampler
for idx, data in enumerate(datasets_list['train']):
label=data['label']
class_weight=class_weights[label]
train_sample_weights[idx]=class_weight
sampler_train=WeightedRandomSampler(train_sample_weights,num_samples=len(train_sample_weights),replacement=True)
for idx, data in enumerate(datasets_list['dev_train']):
label=data['label']
class_weight=class_weights[label]
dev_train_sample_weights[idx]=class_weight
sampler_dev_train=WeightedRandomSampler(dev_train_sample_weights,num_samples=len(dev_train_sample_weights),replacement=True)
for idx, data in enumerate(datasets_list['test']):
label=data['label']
class_weight=class_weights[label]
test_sample_weights[idx]=class_weight
sampler_test=WeightedRandomSampler(test_sample_weights,num_samples=len(test_sample_weights),replacement=True)
datasets_list={}
for group in dict_expl_MCclassifier_training.keys(): #created with all the data from 3 classes
if group =='train':
datasets_list[group] = tio.SubjectsDataset(dict_expl_MCclassifier_training[group], transform=transform_train)
else:
datasets_list[group] = tio.SubjectsDataset(dict_expl_MCclassifier_training[group], transform=transform_dev)
dataloaders_glaucoma_training_MCclassifier={'train': DataLoader(dataset=datasets_list['train'], batch_size=10,sampler= sampler_train, num_workers=8),
'dev_train': DataLoader(dataset=datasets_list['dev_train'], batch_size=10,sampler=sampler_dev_train, num_workers=8),
'test': DataLoader(dataset=datasets_list['test'], batch_size=10,sampler=sampler_test, num_workers=8)}
The problem is after testing the trained model, I can see data from class 1 and 2 are misclassified as class 0(the largest class).
So apparently the weightedsampler has not worked properly but how can I make sure if the weightedsampler is the problem?
p.s.:when I manually and randomly selected similar number of data from 3 classes the classifier is working much better