Simultaneous data augmentation and class balancing?

Hi all,
I want to enlarge my dataset and balance the classes at the same time. I’m using random transforms to get new images in each batch and WeightedRandomSampler to try to balance the classes.
However, I don’t know how to check whether the classes are balanced over the entire dataset, since the sampler acts on each batch. Has anyone dealt with this problem? I’m new to PyTorch. Thanks!
This is my code. I count the number of instances of each class before and after the sampler, and I don’t get balanced classes. Any suggestions? What am I doing wrong?
transform = {
    'train': transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
    'val': transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
}

image_data = torchvision.datasets.ImageFolder(data_path, transform['train'])

image_loader = torch.utils.data.DataLoader(image_data, batch_size=batch_size_t, shuffle=True)

# Data set with transforms

labels = []
for image, label in image_loader:
    for item in label.numpy():
        labels.append(class_names[item])

labels=np.array(labels)

# Check the number of instances for each class

for i in range(0, len(class_names)):
    print('Instances in class ' + str(class_names[i]) + ': ' + str(len(np.where(labels == class_names[i])[0])))

# Calculate the sample weights for the WeightedRandomSampler

weight = np.arange(0, len(class_names)).astype(np.float32)

for j in range(0, len(class_names)):
    weight[j] = 1 / len(np.where(labels == class_names[j])[0])

samples_weight = []

for image, label in image_loader:
    for item in label.numpy():
        samples_weight.append(weight[item])

samples_weight = np.array(samples_weight)

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

image_data = torchvision.datasets.ImageFolder(data_path, transform['train'])
image_loader = DataLoader(image_data, batch_size=batch_size_t, num_workers=4, sampler=sampler)

labels = []
for image, label in image_loader:
    for item in label.numpy():
        labels.append(class_names[item])

labels=np.array(labels)

# Check the number of instances for each class after sampler

print('---------------------------------------------------')
for i in range(0, len(class_names)):
    print('Instances in class ' + str(class_names[i]) + ': ' + str(len(np.where(labels == class_names[i])[0])))

I think the check you implemented should work (though using np.unique probably would be more efficient than handling each class with np.where). Where are you experiencing difficulties with it?

Best regards

Thomas

P.S.: If you enclose your code in triple backticks ``` , it’ll keep the formatting.
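
For reference, the np.unique approach mentioned above could look roughly like the sketch below. The helper name `count_per_class` and the toy `example_labels` list are made up for illustration; in the original code the labels would come from iterating the loader.

```python
import numpy as np

def count_per_class(labels):
    """Return a {class: count} dict from a 1-D sequence of labels."""
    # return_counts=True gives the sorted unique values and how often each occurs.
    classes, counts = np.unique(np.asarray(labels), return_counts=True)
    return {c.item(): n.item() for c, n in zip(classes, counts)}

# Toy labels standing in for the labels collected from the loader (assumption).
example_labels = [0, 2, 1, 2, 2, 0]
for cls, n in count_per_class(example_labels).items():
    print(f"Instances in class {cls}: {n}")
```

This makes a single pass over the label array instead of one np.where call per class.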

Hi Tom, thanks for your suggestion. My main problem is that I don’t know how to check whether the number of instances for each class is balanced. I’ve tried to follow some approaches from this forum, but I think the classes are still unbalanced. Thanks.

But what do you get as output when you run your code?

The output consists of the number of instances that I have for each class before and after the sampler. The numbers are almost the same; I only get slight differences. Am I counting the instances incorrectly?

Output:

Instances in class 0: 127
Instances in class 1: 131
Instances in class 2: 38
Instances in class 3: 154
Instances in class 4: 49
Instances in class 5: 130
Instances in class 6: 264
Instances in class 7: 44
Instances in class 8: 68
Instances in class 9: 242
Instances in class 10: 64
Instances in class 11: 40


Instances in class 0: 120
Instances in class 1: 132
Instances in class 2: 35
Instances in class 3: 174
Instances in class 4: 50
Instances in class 5: 133
Instances in class 6: 285
Instances in class 7: 29
Instances in class 8: 66
Instances in class 9: 213
Instances in class 10: 71
Instances in class 11: 43

Ah, sorry. I think you want replacement=True (and set some number of samples) in the sampler.
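
To illustrate what the replacement flag changes, here is a small sketch on made-up labels (90 samples of class 0 and 10 of class 1; these numbers and the variable names are assumptions, not the poster's data). Without replacement and with num_samples equal to the dataset size, every index is drawn exactly once, so the class counts cannot change; with replacement, rare-class indices can be drawn repeatedly. As far as I know, recent PyTorch releases already default to replacement=True.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical labels in dataset order: 90 of class 0, 10 of class 1.
targets = torch.tensor([0] * 90 + [1] * 10)
# Inverse class frequency, expanded to one weight per sample.
weights = (1.0 / torch.bincount(targets).double())[targets]
n = len(weights)

# Without replacement and num_samples == n the sampler yields a permutation
# of all indices, so the class counts stay exactly as they were.
no_repl = list(WeightedRandomSampler(weights, num_samples=n, replacement=False))
counts_no_repl = torch.bincount(targets[no_repl], minlength=2)

# With replacement, indices of the rare class can be drawn many times.
with_repl = list(WeightedRandomSampler(weights, num_samples=1000, replacement=True))
counts_with_repl = torch.bincount(targets[with_repl], minlength=2)

print("without replacement:", counts_no_repl.tolist())
print("with replacement:   ", counts_with_repl.tolist())
```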

Thanks Tom, I followed your suggestion, but the classes apparently remain unbalanced. I can’t quite understand how the weights work in the sampler, because I’ve tried different approaches and I can’t control the resulting number of instances.
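
On how the weights work, as far as I understand it: WeightedRandomSampler expects one weight per dataset index, and weights[i] is the relative probability of drawing sample i. One possible cause of the imbalance in the code above is that samples_weight is filled while iterating a loader created with shuffle=True, so the weights may not line up with the dataset indices. The sketch below uses a made-up TensorDataset in place of the ImageFolder data; the class sizes, the `targets` tensor, and the batch size are assumptions for illustration. With ImageFolder, the labels in dataset order should be available as image_data.targets in recent torchvision versions, which avoids iterating a loader at all.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Hypothetical imbalanced dataset: 900 samples of class 0, 100 of class 1.
targets = torch.cat([torch.zeros(900, dtype=torch.long),
                     torch.ones(100, dtype=torch.long)])
data = torch.randn(len(targets), 3)
dataset = TensorDataset(data, targets)

# One weight per class (inverse frequency), then one weight per sample,
# indexed in the same order as the dataset itself.
class_counts = torch.bincount(targets)
class_weights = 1.0 / class_counts.double()
samples_weight = class_weights[targets]

sampler = WeightedRandomSampler(samples_weight,
                                num_samples=len(samples_weight),
                                replacement=True)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

# Count the labels actually drawn over one full pass.
drawn = torch.cat([label for _, label in loader])
balanced_counts = torch.bincount(drawn, minlength=2)
print(balanced_counts.tolist())
```

The counts are only balanced in expectation, so each epoch will show some random variation around an even split.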