Hi all,
I want to enlarge my dataset and balance the classes at the same time. I'm using random transforms so that the images are different in each batch, and WeightedRandomSampler to try to balance the classes.
However, I don't know how to check whether the classes are balanced over the entire dataset, since the sampler is acting over each batch. Has anyone dealt with this problem? I'm new to PyTorch. Thanks!
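For reference, `WeightedRandomSampler` does not work per batch: it draws `num_samples` dataset indices for the whole epoch (with replacement by default), and the `DataLoader` then groups those indices into batches. The class balance of an epoch can therefore be checked from the sampler's indices alone, without loading or transforming any image. A minimal sketch, assuming a hypothetical imbalanced `targets` list standing in for the dataset's labels. Setting `num_samples` above the dataset size is one way to get a longer epoch; each repeated draw of the same image goes through the random transforms again.

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical labels standing in for the dataset's targets: 900 of class 0, 100 of class 1
targets = [0] * 900 + [1] * 100

# Per-sample weight = 1 / (number of samples in that sample's class)
class_counts = Counter(targets)
sample_weights = [1.0 / class_counts[t] for t in targets]  # one weight per dataset index

torch.manual_seed(0)  # reproducible draws for this sketch
# Draw twice as many indices as there are samples ("enlarged" epoch)
sampler = WeightedRandomSampler(sample_weights, num_samples=2 * len(targets), replacement=True)

# Iterating the sampler yields the dataset indices for one full epoch
epoch_indices = list(sampler)
epoch_counts = Counter(targets[i] for i in epoch_indices)
print(len(epoch_indices), dict(epoch_counts))
```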
This is my code. I check the number of instances of each class before and after applying the sampler, and the classes are still not balanced. Any suggestions? What am I doing wrong?
```python
import numpy as np
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, WeightedRandomSampler

# data_path, batch_size_t and class_names are defined elsewhere

transform = {'train': transforms.Compose([transforms.Resize((32, 32)),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.RandomRotation(10),
                                          transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
                                          transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
             'val': transforms.Compose([transforms.Resize((32, 32)),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
             }

image_data = torchvision.datasets.ImageFolder(data_path, transform['train'])
image_loader = torch.utils.data.DataLoader(image_data, batch_size=batch_size_t, shuffle=True)

# Dataset with transforms: collect the label of every sample
labels = []
for image, label in image_loader:
    for item in label.numpy():
        labels.append(class_names[item])
labels = np.array(labels)

# Check the number of instances for each class
for i in range(len(class_names)):
    print('Instances in class ' + str(class_names[i]) + ': ' + str(len(np.where(labels == class_names[i])[0])))
```
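Counting instances per class does not require iterating the loader, which decodes and transforms every image. `ImageFolder` stores the class index of every sample, in dataset order, in its `targets` attribute, and the class names in `classes`. A minimal sketch, where `class_names` and `targets` are hypothetical stand-ins for `image_data.classes` and `image_data.targets`:

```python
from collections import Counter

# Hypothetical stand-ins for image_data.classes and image_data.targets
class_names = ["cat", "dog", "bird"]
targets = [0, 0, 0, 0, 1, 1, 2]

# Counter returns 0 for a class index that never occurs
counts = Counter(targets)
for idx, name in enumerate(class_names):
    print(f"Instances in class {name}: {counts[idx]}")
```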
```python
# Calculate the sample weights for the WeightedRandomSampler
weight = np.zeros(len(class_names), dtype=np.float32)
for j in range(len(class_names)):
    weight[j] = 1 / len(np.where(labels == class_names[j])[0])

samples_weight = []
for image, label in image_loader:
    for item in label.numpy():
        samples_weight.append(weight[item])
samples_weight = np.array(samples_weight)
samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
```
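`WeightedRandomSampler` treats `weights[i]` as the weight of dataset index `i`, so the per-sample weights have to be listed in dataset order. Labels collected from a loader created with `shuffle=True` come out in random order, so weights built from them are not aligned with the indices they are meant to describe. A minimal vectorized sketch with NumPy, assuming integer class labels in dataset order (for an `ImageFolder`, `image_data.targets`); the `targets` array below is hypothetical:

```python
import numpy as np

# Hypothetical class index of each sample, in dataset order (e.g. image_data.targets)
targets = np.array([0, 0, 0, 0, 0, 0, 1, 1, 2])

class_counts = np.bincount(targets)      # samples per class: [6, 2, 1]
class_weights = 1.0 / class_counts       # weight per class: 1 / n_c
samples_weight = class_weights[targets]  # weight per sample, aligned with dataset index

# Each class now carries the same total weight (n_c * 1/n_c = 1)
per_class_total = np.bincount(targets, weights=samples_weight)
print(per_class_total)
```

The result can then be passed to the sampler as `WeightedRandomSampler(torch.from_numpy(samples_weight), len(samples_weight))`.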
```python
image_data = torchvision.datasets.ImageFolder(data_path, transform['train'])
image_loader = DataLoader(image_data, batch_size=batch_size_t, num_workers=4, sampler=sampler)

labels = []
for image, label in image_loader:
    for item in label.numpy():
        labels.append(class_names[item])
labels = np.array(labels)

# Check the number of instances for each class after the sampler
print('---------------------------------------------------')
for i in range(len(class_names)):
    print('Instances in class ' + str(class_names[i]) + ': ' + str(len(np.where(labels == class_names[i])[0])))
```
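The same before/after comparison can be reproduced without a dataset on disk. A minimal sketch, assuming a hypothetical `TensorDataset` of random tensors with imbalanced labels standing in for the `ImageFolder`. Here the weights are built from the labels in dataset order, and `shuffle` is left out of the sampled loader because `DataLoader` does not accept `shuffle=True` together with a `sampler`:

```python
from collections import Counter

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical imbalanced dataset: 950 samples of class 0, 50 of class 1
labels = torch.cat([torch.zeros(950, dtype=torch.long), torch.ones(50, dtype=torch.long)])
images = torch.randn(len(labels), 3, 8, 8)  # small random "images"
dataset = TensorDataset(images, labels)


def count_labels(loader):
    """Count how many samples of each class the loader yields in one epoch."""
    counts = Counter()
    for _, batch_labels in loader:
        counts.update(batch_labels.tolist())
    return counts


# Before: plain shuffling keeps the original class proportions
before = count_labels(DataLoader(dataset, batch_size=64, shuffle=True))

# After: per-sample weights 1 / n_c, indexed in dataset order
class_counts = torch.bincount(labels)
samples_weight = (1.0 / class_counts.double())[labels]
sampler = WeightedRandomSampler(samples_weight, num_samples=len(samples_weight), replacement=True)
after = count_labels(DataLoader(dataset, batch_size=64, sampler=sampler))

print("before:", dict(before))
print("after: ", dict(after))
```

With replacement, the minority-class samples are drawn many times per epoch, so the "after" counts should come out roughly equal rather than exactly equal.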