I need to implement a multi-label image classification model in PyTorch. My data is imbalanced, so I used PyTorch's WeightedRandomSampler
to create a custom dataloader. But when I iterate through this dataloader, I get the error: IndexError: list index out of range
I implemented the following code based on this link: Balanced Sampling between classes with torchvision DataLoader

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

def make_weights_for_balanced_classes(images, nclasses):
    # images is a list of (path, class_index) pairs
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1
    # weight of a class is inversely proportional to its frequency
    weight_per_class = [0.] * nclasses
    N = float(sum(count))
    for i in range(nclasses):
        weight_per_class[i] = N / float(count[i])
    # every sample gets the weight of its class
    weight = [0] * len(images)
    for idx, val in enumerate(images):
        weight[idx] = weight_per_class[val[1]]
    return weight

weights = make_weights_for_balanced_classes(train_dataset.imgs, len(full_dataset.classes))
weights = torch.DoubleTensor(weights)
sampler = WeightedRandomSampler(weights, len(weights))
train_loader = DataLoader(train_dataset, batch_size=4, sampler=sampler, pin_memory=True)

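
One thing that may be worth checking here: WeightedRandomSampler draws indices from range(len(weights)), so if the weights list is longer than the dataset that is actually passed to DataLoader (for example, when train_dataset is a split of full_dataset), some sampled indices would fall outside the dataset. Whether that is what happens in my case is an assumption, not something confirmed. The sketch below uses hypothetical stand-in data (a TensorDataset split with random_split), a hypothetical helper make_weights_from_labels, and one integer label per sample; it reads the labels from the split itself so that there is exactly one weight per training sample.

```python
import torch
from torch.utils.data import (
    DataLoader,
    TensorDataset,
    WeightedRandomSampler,
    random_split,
)


def make_weights_from_labels(labels, nclasses):
    """Return one inverse-frequency weight per sample (hypothetical helper)."""
    labels = torch.as_tensor(labels, dtype=torch.long)
    count = torch.bincount(labels, minlength=nclasses).double()
    # Classes absent from this split get weight 0 instead of infinity.
    weight_per_class = torch.where(
        count > 0, count.sum() / count, torch.zeros_like(count)
    )
    return weight_per_class[labels]


# Stand-in data (assumption): 10 samples, 2 classes, imbalanced 8:2.
features = torch.randn(10, 3)
targets = torch.tensor([0] * 8 + [1] * 2)
full_dataset = TensorDataset(features, targets)
train_dataset, _ = random_split(
    full_dataset, [6, 4], generator=torch.Generator().manual_seed(0)
)

# Read the labels from the split that will be passed to DataLoader.
train_labels = [int(train_dataset[i][1]) for i in range(len(train_dataset))]
weights = make_weights_from_labels(train_labels, nclasses=2)

# The sampler yields indices in range(len(weights)), so the lengths must match.
assert len(weights) == len(train_dataset)

sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
train_loader = DataLoader(train_dataset, batch_size=4, sampler=sampler)

# Iterating should now stay within the bounds of train_dataset.
n_seen = sum(x.shape[0] for x, _ in train_loader)
```

For a truly multi-label dataset, where each sample carries several labels, a single class index per sample does not exist, so the per-sample weight would have to be defined differently (that part is not covered by this sketch).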
Thanks a lot in advance!