I am struggling to train a classifier because my dataset is very imbalanced. There are 40 classes, with between 200 and 2000 samples each. The classifier tends to learn the classes with the most samples, but has very poor accuracy on the classes with the fewest samples.
I found this thread about using `WeightedRandomSampler`, but have not managed to implement it in my code.
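The idea behind `WeightedRandomSampler` with inverse-frequency weights appears to be that each sample of class c gets weight 1/n_c, so every class ends up with the same total weight and is drawn equally often in expectation. A small pure-Python check of that arithmetic (not PyTorch code; `class_shares` is a hypothetical helper, and only three of the 40 class sizes are used):

```python
from fractions import Fraction


def class_shares(counts, inverse_frequency=True):
    """Expected fraction of draws that come from each class.

    Each sample of class c gets weight 1/counts[c] when inverse_frequency
    is True, otherwise weight 1 (plain uniform sampling / shuffling).
    """
    # Summed weight of all samples belonging to each class.
    masses = [
        count * (Fraction(1, count) if inverse_frequency else Fraction(1))
        for count in counts
    ]
    total = sum(masses)
    # Normalize so the shares sum to 1.
    return [mass / total for mass in masses]


counts = [200, 400, 2000]  # three of the class sizes from the question
print(class_shares(counts, inverse_frequency=False))
print(class_shares(counts))
```

With uniform sampling the 200-sample class receives 1/13 of the draws; with 1/n_c weights each of the three classes receives 1/3.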
The relevant code is below.
When I run it, it fails at the line `train_targets = [sample[1] for sample in dataset_train.imgs]` with:

`AttributeError: 'Subset' object has no attribute 'imgs'`
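As far as I can tell, a `torch.utils.data.Subset` only stores the wrapped dataset (`.dataset`) and the selected positions (`.indices`); the `imgs` and `targets` lists live on the underlying `ImageFolder`. A minimal sketch of reading the labels through those two attributes (`subset_targets` is a hypothetical helper; it assumes the wrapped dataset exposes `targets`, as torchvision's `ImageFolder` does):

```python
def subset_targets(subset):
    """Class index of every sample selected by a Subset-like object.

    Assumes subset.dataset has a `targets` list (one class index per
    sample, as in ImageFolder) and subset.indices lists the chosen positions.
    """
    targets = subset.dataset.targets
    # Map each selected position back to the label in the full dataset.
    return [targets[i] for i in subset.indices]
```

Used as `train_targets = subset_targets(dataset_train)`, this would stand in for the failing line; `[dataset_train.dataset.imgs[i][1] for i in dataset_train.indices]` should be equivalent.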
My question has two parts:
- How can I resolve the above error?
- Looking more closely at my code, I believe the problems with the imbalanced dataset actually begin when the single dataset is split into training and validation sets. Because the split is random, the validation set could, for instance, happen to contain all the samples of a certain class, leaving none for the training set. How can I ensure that the split is reasonable, i.e. that every class is represented in both sets in roughly the same proportion?
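One common way to get such a split is a *stratified* split: shuffle and split the indices of each class separately, so every class keeps about the same train/validation ratio. A minimal pure-Python sketch (`stratified_split` is a hypothetical helper; the "keep at least one training sample per class" rule is an assumption added for safety):

```python
import random
from collections import defaultdict


def stratified_split(targets, valid_fraction, seed=0):
    """Split sample indices so each class keeps roughly the same
    proportion in the training and validation sets.

    Each class contributes round(valid_fraction * n_c) samples to
    validation, capped so at least one sample per class stays in training.
    """
    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)

    rng = random.Random(seed)  # seeded for a reproducible split
    train_idx, valid_idx = [], []
    for label in sorted(by_class):
        idxs = by_class[label]
        rng.shuffle(idxs)
        n_valid = min(round(valid_fraction * len(idxs)), len(idxs) - 1)
        valid_idx.extend(idxs[:n_valid])
        train_idx.extend(idxs[n_valid:])
    return train_idx, valid_idx
```

With an `ImageFolder`, the labels would come from `dataset.targets`, and the returned index lists could replace the `torch.randperm` slices in `Subset(dataset, train_idx)` and `Subset(dataset_test, valid_idx)`. scikit-learn's `train_test_split(..., stratify=targets)` applied to `range(len(dataset))` is an alternative that does the same job.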
```python
import torch
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torchvision import datasets

# ROOT_DIR, VALID_SPLIT, BATCH_SIZE, NUM_WORKERS, transform_train and
# transform_valid are defined elsewhere (not shown).


def get_datasets():
    # The same image folder is loaded twice: once with the training
    # transform and once with the validation transform.
    dataset = datasets.ImageFolder(
        ROOT_DIR,
        transform=transform_train
    )
    dataset_test = datasets.ImageFolder(
        ROOT_DIR,
        transform=transform_valid
    )
    dataset_size = len(dataset)
    valid_size = int(VALID_SPLIT * dataset_size)
    # Randomize the data indices:
    indices = torch.randperm(len(dataset)).tolist()
    # Training and validation sets:
    dataset_train = Subset(dataset, indices[:-valid_size])
    dataset_valid = Subset(dataset_test, indices[-valid_size:])
    return dataset_train, dataset_valid, dataset.classes


def get_data_loaders(dataset_train, dataset_valid):
    # Make a sampler to balance the dataset:
    class_sample_count = [200, 400, 320, 690, 1120, 1200, 390, 200, 2000, 310,
                          310, 330, 390, 610, 220, 1550, 600, 230, 530, 750,
                          200, 1950, 2000, 1410, 1980, 1860, 420, 1440, 1410, 1470,
                          350, 740, 1150, 1900, 220, 2000, 250, 1980, 1100, 300]
    # Per-class weight: inverse of the number of samples in that class
    class_weights = 1. / torch.Tensor(class_sample_count)
    # Fails here: 'Subset' object has no attribute 'imgs'
    train_targets = [sample[1] for sample in dataset_train.imgs]
    # Per-sample weight: the weight of the sample's class
    train_samples_weight = [class_weights[class_id] for class_id in train_targets]
    sampler = WeightedRandomSampler(train_samples_weight.type('torch.DoubleTensor'), len(samples_weight))
    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_samples_weight, len(train_dataset))
    train_loader = DataLoader(
        dataset_train, batch_size=BATCH_SIZE,
        shuffle=True, num_workers=NUM_WORKERS,
        sampler=train_sampler,
    )
    valid_loader = DataLoader(
        dataset_valid, batch_size=BATCH_SIZE,
        shuffle=False, num_workers=NUM_WORKERS,
    )
    return train_loader, valid_loader


dataset_train, dataset_valid, dataset_classes = get_datasets()
train_loader, test_loader = get_data_loaders(dataset_train, dataset_valid)
```
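Once the labels of the training subset are available, the per-sample weights could be derived from the actual training labels instead of the hard-coded `class_sample_count` list. That avoids relying on the list being in the same (alphabetical) class order as `dataset.classes`, and on full-dataset counts matching the training split. A minimal sketch (`inverse_frequency_weights` is a hypothetical helper):

```python
from collections import Counter


def inverse_frequency_weights(train_targets):
    """Per-sample weights for WeightedRandomSampler: 1 / (number of
    training samples in that sample's class)."""
    counts = Counter(train_targets)  # class index -> count in the training split
    return [1.0 / counts[label] for label in train_targets]
```

The resulting plain list could then be passed as `WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)`, without the `.type('torch.DoubleTensor')` call. Note that the `sampler` line refers to `samples_weight` and the `train_sampler` line to `train_dataset`, neither of which is defined, and that `DataLoader` raises a `ValueError` when `sampler` is combined with `shuffle=True`, so the training loader would need `shuffle` left at its default.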