How to implement oversampling in Cifar-10?

Hi,

I need to train a convolutional network on the CIFAR-10 dataset using some oversampling technique, but I do not know how to do it in PyTorch. Since CIFAR-10 is a balanced dataset, I first need to simulate a class imbalance in it and then apply an oversampling technique. Could someone give me an example?


Would it work for you if you just oversample a specific class?


import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, WeightedRandomSampler

train_dataset = torchvision.datasets.CIFAR10(root='YOUR_PATH',
                                             transform=torchvision.transforms.ToTensor())
# Newer torchvision versions store the labels in `targets` (formerly `train_labels`)
target = np.array(train_dataset.targets)
class_sample_count = np.unique(target, return_counts=True)[1]
print(class_sample_count)

# oversample class 0: its weight becomes 1/50 instead of 1/5000
class_sample_count[0] = 50

weight = 1. / class_sample_count
samples_weight = weight[target]
samples_weight = torch.from_numpy(samples_weight)
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

train_loader = DataLoader(
    train_dataset, batch_size=10, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
    print("batch index {}, class 0 {}, other classes {}".format(
        i,
        len(np.where(target.numpy() == 0)[0]),
        len(np.where(target.numpy() != 0)[0])))
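Why this oversamples class 0: WeightedRandomSampler draws indices with probability proportional to their weights (with replacement by default), so the share each class gets is its number of samples times its per-sample weight. The sketch below, assuming the balanced CIFAR-10 training counts of 5000 images per class, works out that share after setting the class-0 count to 50 as in the snippet above.

```python
import numpy as np

# Balanced CIFAR-10 training set: 5000 images per class (assumption for this sketch)
class_sample_count = np.full(10, 5000)
class_sample_count[0] = 50          # same trick as in the snippet above

weight = 1.0 / class_sample_count   # per-sample weight of each class
# Class 0 still has 5000 real samples; only its weight changed.
# Total weight mass per class = (#samples in class) * (per-sample weight)
mass = 5000 * weight
expected_share = mass / mass.sum()

print(expected_share.round(3))
```

Class 0 ends up with a weight mass of 100 against 1 for each other class, i.e. roughly 100/109 ≈ 92% of all draws, so most batches of 10 will be dominated by class 0.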

Thank you @ptrblck.

Does this code unbalance just one class? Would I need to adjust any parameters in the training?

Sorry, I might have misunderstood your question.
So you would first like to simulate an imbalanced dataset and then oversample it so that each batch contains roughly equal numbers of samples from each class?

Yes, that’s right. I want to simulate a real-world dataset, since in the real world the classes are usually imbalanced. However, I need to adjust the network so that it still learns the imbalanced classes.

I have tested your example code above and the accuracy was very low.

Ah ok, sorry for the misunderstanding.

You could do the following.
First, let’s create an imbalanced CIFAR10 dataset. In the original CIFAR10 training set, each class has 5000 instances.
For simplicity, let’s just use 500 instances of class 0, 5000 instances of class 1, 500 instances of class 2, …

import numpy as np
from torchvision import datasets, transforms

# Load CIFAR10
dataset = datasets.CIFAR10(
    root='YOUR_PATH',
    transform=transforms.ToTensor())

# Get all training targets and count the number of class instances
# (newer torchvision versions use `targets`/`data` instead of `train_labels`/`train_data`)
targets = np.array(dataset.targets)
classes, class_counts = np.unique(targets, return_counts=True)
nb_classes = len(classes)
print(class_counts)

# Create artificial imbalanced class counts
imbal_class_counts = [500, 5000] * 5

# Get class indices
class_indices = [np.where(targets == i)[0] for i in range(nb_classes)]

# Get imbalanced number of instances
imbal_class_indices = [class_idx[:class_count] for class_idx, class_count in zip(class_indices, imbal_class_counts)]
imbal_class_indices = np.hstack(imbal_class_indices)

# Set target and data to dataset
dataset.targets = targets[imbal_class_indices]
dataset.data = dataset.data[imbal_class_indices]

assert len(dataset.targets) == len(dataset.data)
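Overwriting the label and image attributes relies on the dataset's internal attribute names, which have changed between torchvision versions. A possible alternative, sketched here under the assumption that you prefer not to touch those attributes, is to wrap the dataset in torch.utils.data.Subset with the selected indices. A small fake TensorDataset stands in for CIFAR-10 so the example runs without downloading anything, and `make_imbalanced_indices` is a hypothetical helper name.

```python
import numpy as np
import torch
from torch.utils.data import Subset, TensorDataset

def make_imbalanced_indices(targets, imbal_class_counts):
    """Keep the first imbal_class_counts[c] indices of every class c (hypothetical helper)."""
    targets = np.asarray(targets)
    per_class = [np.where(targets == c)[0][:n] for c, n in enumerate(imbal_class_counts)]
    return np.hstack(per_class)

# Stand-in for CIFAR-10: 10 classes with 5000 samples each, tiny fake images
targets = np.repeat(np.arange(10), 5000)
images = torch.zeros(len(targets), 3, 4, 4)
full_dataset = TensorDataset(images, torch.from_numpy(targets))

imbal_class_counts = [500, 5000] * 5
indices = make_imbalanced_indices(targets, imbal_class_counts)
imbal_dataset = Subset(full_dataset, indices.tolist())

# Labels of the subset, handy for computing sampler weights later
imbal_targets = targets[indices]
print(len(imbal_dataset), np.bincount(imbal_targets))
```

With a real CIFAR10 object you would pass `np.array(dataset.targets)` instead of the fake labels and wrap the original dataset, so its transform still runs on every image.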

Now that we have thrown out a lot of samples, let’s have a look at the training loop.

from torch.utils.data import DataLoader

loader = DataLoader(
    dataset, batch_size=64, shuffle=True)

# Here we have an imbalanced dataset
for batch_idx, (data, target) in enumerate(loader):
    print('Batch {}, classes {}, count {}'.format(
        batch_idx, *np.unique(target.numpy(), return_counts=True)))
    # Your model will most likely perform badly and overfit on the
    # majority classes

In this loop, the samples are imbalanced and your model will most likely overfit on the majority classes.
You can also resample in a more imbalanced way to see the effect more clearly.
Try to train your model here and have a look at the performance.

To counter the imbalanced dataset, let’s create a WeightedRandomSampler, which draws the samples using probabilities (weights).

import torch
from torch.utils.data import WeightedRandomSampler

# Oversample the minority classes
targets = np.array(dataset.targets)
class_count = np.unique(targets, return_counts=True)[1]
print(class_count)

# Each sample gets the inverse frequency of its class as weight
weight = 1. / class_count
samples_weight = weight[targets]
samples_weight = torch.from_numpy(samples_weight)
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
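To check that these weights really balance the classes without loading any images, the sketch below builds synthetic labels with the same [500, 5000] * 5 counts, computes the weights the same way, and counts how often each class is drawn. The synthetic labels and the fixed seed are assumptions for the demo; because sampling is random with replacement, the per-class counts only come out roughly equal (about 2750 each).

```python
import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

# Synthetic labels with the same imbalanced counts as the example above
imbal_class_counts = [500, 5000] * 5
targets = np.repeat(np.arange(10), imbal_class_counts)

class_count = np.unique(targets, return_counts=True)[1]
weight = 1. / class_count                 # rarer classes get larger weights
samples_weight = torch.from_numpy(weight[targets])

generator = torch.Generator().manual_seed(0)   # fixed seed for reproducibility
sampler = WeightedRandomSampler(samples_weight, len(samples_weight),
                                replacement=True, generator=generator)

# The sampler yields dataset indices; look up which class each one belongs to
drawn_classes = targets[list(sampler)]
print(np.bincount(drawn_classes, minlength=10))
```

Note that with replacement=True the minority images are repeated many times per epoch (each appears about 5.5 times on average), which is exactly what oversampling means here, but it also makes overfitting on those few images more likely.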

Now we can create a training loop again and have a look at the sample distribution.

weighted_loader = DataLoader(
    dataset, batch_size=64, sampler=sampler)

# Here we have a balanced dataset due to oversampling
for batch_idx, (data, target) in enumerate(weighted_loader):
    print('Batch {}, classes {}, count {}'.format(
        batch_idx, *np.unique(target.numpy(), return_counts=True)))
    # Your model will probably perform better here

Now the samples are drawn in a more uniform way, and your model will probably perform better.
Try to train your model again and compare the accuracies.
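When comparing accuracies, the overall accuracy on an imbalanced test set can hide poor performance on the minority classes, so it may help to look at per-class accuracy as well. A minimal sketch, assuming you have collected the true labels and the predicted labels (e.g. `output.argmax(1)`) of the whole test set as integer arrays; `per_class_accuracy` is a hypothetical helper name.

```python
import numpy as np

def per_class_accuracy(y_true, y_pred, nb_classes):
    """Fraction of correctly predicted samples for each class (hypothetical helper)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Count correct predictions and total samples per true class
    correct = np.bincount(y_true[y_true == y_pred], minlength=nb_classes)
    total = np.bincount(y_true, minlength=nb_classes)
    # Classes missing from y_true get 0 instead of a division by zero
    return np.divide(correct, total, out=np.zeros(nb_classes), where=total > 0)

# Toy example: the model always predicts class 1 (the majority class)
y_true = np.array([0, 1, 1, 1, 1])
y_pred = np.array([1, 1, 1, 1, 1])
print((y_true == y_pred).mean())               # overall accuracy
print(per_class_accuracy(y_true, y_pred, 2))   # per-class accuracy
```

In this toy case the overall accuracy is 0.8 even though class 0 is never predicted correctly, which is the kind of gap the per-class numbers expose.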

Let me know if you can use this example!


Thank you so much for helping me. I’ll test your solution and get back to you when I have results.

Hi @ptrblck,

Sorry for the delay. I have trained the model with your oversampling example; however, the accuracy of the model has been very low =/

Did you resample the test/validation set as well?
How was the training accuracy?

I resampled the test set in the same way as the training set, changing imbal_class_counts to imbal_class_counts = [100, 1000] * 5. The test accuracy was 25% after 30 epochs, while the training loss was close to zero.

I’ll also try the over_sampling.SMOTE class from the imbalanced-learn library to see whether the accuracy remains low.

Hi @ptrblck,

There was an error in my test function, which is why the accuracy was so low. Now the accuracy is better: not as good as the result with the balanced dataset, but acceptable. Sorry for the mistake, and thank you very much for the help.

That’s good to hear! I was checking it myself, since I wanted to have a look at the accuracy.
Would you mind sharing the accuracies you have achieved using the balanced and imbalanced datasets?

I have not finished training yet, but on the balanced dataset I got 93% accuracy, and on the imbalanced dataset it’s currently at 81%, with the pre-training still running. When I finish training, I’ll post the final value here. Thanks again.

Hi @ptrblck,

I achieved 84% accuracy in the final model with the imbalanced dataset. That is a lot lower than the 93% I got with the balanced dataset, so I’m trying another oversampling strategy, but I’m having a problem. Since the oversampling function I am using is from another library, I have to transform the data before I can use it, like this:

data, label = SMOTE().fit_resample(d2_train_data, dataset.train_labels)  # fit_sample in older imbalanced-learn versions

where d2_train_data is the training data reshaped into a 2D array (one flattened image per row), as required by SMOTE.
Then I reshape the data back to its original format and load it using a DataLoader:

dataset.train_data = torch.from_numpy(dataset.train_data)
dataset.train_labels = torch.from_numpy(dataset.train_labels)

dataset = torch.utils.data.TensorDataset(dataset.train_data, dataset.train_labels)

loader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

However, in the end I get an error related to the weights of the network:

RuntimeError: Given groups=1, weight[64, 3, 7, 7], so expected input[64, 32, 32, 3] to have 3 channels, but got 32 channels instead

I think it’s because I can no longer apply the transform after this modification. Am I doing something wrong?


That’s good news! Your SMOTE approach sounds interesting.

Your error is most likely caused by the image loading.
It seems the channels are in dimension 3, while PyTorch expects them in dimension 1 (i.e. [batch, channels, height, width]).
Try to permute your images where you are loading them:

image = image.permute(0, 3, 1, 2).contiguous()
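Put together, turning the resampled flat arrays back into a dataset that a Conv2d model can consume might look like the sketch below. It assumes the resampled images come back as one flattened 32x32x3 (channels-last) row per sample with pixel values in 0..255, as you get when flattening the raw CIFAR-10 image array; random arrays stand in for the SMOTE output. Since the ToTensor transform no longer runs on a TensorDataset, its scaling to float values in [0, 1] is done by hand.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the resampler output: N flattened HWC images and their labels
rng = np.random.default_rng(0)
data = rng.integers(0, 256, size=(8, 32 * 32 * 3)).astype(np.float64)
label = rng.integers(0, 10, size=8)

# Undo the flattening: (N, 32*32*3) -> (N, 32, 32, 3), i.e. channels last
images = torch.from_numpy(data.reshape(-1, 32, 32, 3))

# Move channels to dim 1 -> (N, 3, 32, 32) and mimic ToTensor's [0, 1] float scaling
images = images.permute(0, 3, 1, 2).contiguous().float() / 255.

dataset = TensorDataset(images, torch.from_numpy(label).long())
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch, target = next(iter(loader))
print(batch.shape, batch.dtype)   # channels are now in dimension 1
```

Doing the conversion once on the whole tensor keeps the DataLoader simple; any random augmentation you previously applied through `transform` would have to be reimplemented separately.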

Let me know if it helps!

Thanks. I’ll test it and get back to you later.

Josiane, could you please provide more details on how you use SMOTE on images, if I understand correctly?
Thanks.

Hi @vfdev-5,
Do you want to understand how SMOTE works, or how I implemented it in my code? I have not yet tested the solution that @ptrblck suggested for the error I reported above, since I am trying to solve another problem, but I will come back to it as soon as possible.

@Josiane_Rodrigues I wanted to understand how SMOTE works on images and whether it really makes sense to apply it to images (i.e., is SMOTE on images just a blending of images?).
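For reference, the core SMOTE step creates a synthetic sample by interpolating between a minority-class sample x and one of its k nearest neighbours x_nn from the same class: x_new = x + λ(x_nn - x), with λ drawn uniformly from [0, 1). On flattened images this is indeed a pixel-wise blend of two images. The sketch below implements only that step in NumPy, roughly following the original SMOTE idea; it is not the imbalanced-learn implementation, `smote_like_oversample` is a hypothetical name, and the brute-force distance matrix is an assumption that only scales to small toy inputs.

```python
import numpy as np

def smote_like_oversample(x_minority, n_new, k=5, seed=0):
    """Create n_new synthetic rows by SMOTE-style interpolation (hypothetical helper).

    x_minority: (n, d) array of flattened minority-class samples, with n > k.
    """
    rng = np.random.default_rng(seed)
    x = np.asarray(x_minority, dtype=np.float64)
    # Brute-force pairwise squared Euclidean distances within the minority class
    dists = ((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(dists, np.inf)                  # a sample is not its own neighbour
    neighbours = np.argsort(dists, axis=1)[:, :k]    # k nearest neighbours per sample

    base = rng.integers(0, len(x), size=n_new)             # samples to start from
    nn = neighbours[base, rng.integers(0, k, size=n_new)]  # one random neighbour each
    lam = rng.random((n_new, 1))                           # interpolation factor in [0, 1)
    return x[base] + lam * (x[nn] - x[base])

# Toy usage: 20 flattened "images" of 3*4*4 pixels, create 30 synthetic ones
rng = np.random.default_rng(1)
minority = rng.integers(0, 256, size=(20, 3 * 4 * 4))
synthetic = smote_like_oversample(minority, n_new=30, k=5)
print(synthetic.shape)
```

Every synthetic row lies on the line segment between two real minority images, so the new "images" look like semi-transparent overlays of their two parents; whether that helps a CNN is an empirical question for the specific dataset.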