Balanced Sampling between classes with torchvision DataLoader

Ok, I found some other issues. I changed your code a bit. This should work:

import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

numDataPoints = 1000
data_dim = 5
bs = 100

# dummy features and a 9:1 imbalanced target
data = torch.randn(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

dataset_x = data.numpy()
dataset_y = target
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(dataset_x,
                                                    dataset_y,
                                                    test_size=0.33,
                                                    random_state=42,
                                                    stratify = dataset_y)
print('target train: {}/{}'.format(len(np.where(y_train == 0)[0]),
                                   len(np.where(y_train == 1)[0])))

class_sample_count = np.array([len(np.where(y_train==t)[0]) for t in np.unique(y_train)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in y_train])

samples_weight = torch.from_numpy(samples_weight).double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

trainDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(y_train.astype(int)))
validDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(y_test.astype(int)))

trainLoader = torch.utils.data.DataLoader(dataset=trainDataset, batch_size=bs, num_workers=1, sampler=sampler)
testLoader = torch.utils.data.DataLoader(dataset=validDataset, batch_size=bs, shuffle=False, num_workers=1)

for i, (data, target) in enumerate(trainLoader):
    print("batch index {}, 0/1: {}/{}".format(
        i, len(np.where(target.numpy() == 0)[0]), len(np.where(target.numpy() == 1)[0])))

It seems that the weights should have the same length as your number of samples.
I created a dummy dataset with a target imbalance of 9 to 1.
The for loop should iterate over all your training samples, with each batch containing approximately the same number of zeros and ones.
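
As a side note, the per-sample weights from the snippet above can also be built with NumPy indexing instead of a list comprehension. This is just an equivalent sketch, reusing y_train from the example and assuming the labels start at 0:

import numpy as np
import torch

# number of samples per class
class_sample_count = np.bincount(y_train)
weight = 1. / class_sample_count

# pick the weight of each sample's class via fancy indexing
samples_weight = torch.from_numpy(weight[y_train]).double()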

I hope it works for you!

Thanks for your great response! But one more question about your code: could you please explain how I can determine how many times the for loop executes?

No problem!
Sure, you can call len(trainLoader), which will return 7 for this example: 6 full batches of 100 samples and a final one with only 70 samples.
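
For reference, the batch count follows directly from the split sizes in the example above (1000 samples, test_size=0.33, batch size 100); a quick sanity check:

import math

n_train = 1000 - 330              # 670 training samples after the 33% test split
print(math.ceil(n_train / 100))   # 7 -> 6 full batches of 100 plus one batch of 70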

Update: as of PyTorch 0.2, sampler is mutually exclusive with shuffle.

By default, WeightedRandomSampler uses replacement=True, in which case the samples in a batch are not necessarily unique.

Also, shouldn't the call be sampler = WeightedRandomSampler(samples_weight, batch_size)?

You are right! Since replacement=True is the default option, there will probably be repetitions.

If you pass batch_size as the num_samples argument, your DataLoader will just return a single batch of samples in the loop.
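
To illustrate the difference, here is a minimal sketch reusing trainDataset, samples_weight and bs from the example above:

# num_samples controls how many indices the sampler yields per epoch;
# with num_samples=bs the loader below produces exactly one batch
one_batch_sampler = WeightedRandomSampler(samples_weight, num_samples=bs)
one_batch_loader = torch.utils.data.DataLoader(trainDataset, batch_size=bs, sampler=one_batch_sampler)
print(len(one_batch_loader))  # 1

# replacement=False yields unique indices instead, but then num_samples
# must not exceed the number of available samples
unique_sampler = WeightedRandomSampler(samples_weight, num_samples=len(samples_weight),
                                       replacement=False)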

I was able to use this weighted random data sampling with @ptrblck's code.

Here is a snippet:

print('Loading data...')
train_dataset = datasets.ImageFolder(root=traindir,
                                     transform=transform_train)
test_dataset = datasets.ImageFolder(root=valdir,
                                    transform=transform_test)
print('Loading is Done!')

num_classes = len(train_dataset.classes)

# TODO: read class counts from the file
class_sample_counts = [10647, 5659, 31445, 40283, 800, 407, 1111, 22396, 610, 1288, 5708, 1538, 1848, 26015, 17639, 3859, 473, 2509, 579, 2636, 822, 1616, 1226, 949, 1725, 1306, 1758, 1704, 10637, 1091, 1036, 1292, 474, 569, 1682, 553, 506, 7571, 3598, 2280, 24291, 5725, 1319, 824, 5456, 1781, 4074, 2538, 5032, 503, 1623, 7251, 599, 9037, 12221, 2128, 2290, 459, 1549, 1739, 2297, 838, 469, 674, 1030, 994, 704, 672, 1690, 2442, 766, 578, 2032, 534, 552, 13934, 1138, 1372]

# compute a weight for every sample in the dataset;
# the *_samples_weight lists hold the (unnormalized) probability of each example being drawn
class_weights = 1. / torch.Tensor(class_sample_counts)
train_targets = [sample[1] for sample in train_dataset.imgs]
train_samples_weight = [class_weights[class_id] for class_id in train_targets]
test_targets = [sample[1] for sample in test_dataset.imgs]
test_samples_weight = [class_weights[class_id] for class_id in test_targets]

# now initialize the samplers
train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_samples_weight, len(train_dataset))
test_sampler = torch.utils.data.sampler.WeightedRandomSampler(test_samples_weight, len(test_dataset))

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, **kwargs)
val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, sampler=test_sampler, **kwargs)
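
As a side note, instead of hard-coding class_sample_counts, the counts can be derived from the ImageFolder targets directly. A minimal sketch, assuming train_dataset is the ImageFolder defined above:

import numpy as np
import torch

# labels of all images, in dataset order
train_targets = np.array([label for _, label in train_dataset.imgs])

# number of samples per class, and the inverse as per-class weights
class_sample_counts = np.bincount(train_targets)
class_weights = 1. / torch.tensor(class_sample_counts, dtype=torch.float)

# one weight per sample, via indexing with the label array
train_samples_weight = class_weights[torch.as_tensor(train_targets)]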

@ptrblck I have a question:

With the plain sampler I needed 10 GB of GPU memory for a batch size of 64.
However, with the WeightedRandomSampler the same batch size uses only 2.3 GB of GPU memory.

Does this have to do with repetitions? In fact, my data is heavily imbalanced.
Or should I be concerned?

Thanks!

The sampling shouldn’t influence the GPU memory.
I assume your data is on the CPU, you sample from it, and push it to the GPU inside the training loop.
Therefore, your training procedure does not know anything about the sampling procedure.
Your issue sounds quite strange.

Did you make sure the batch size stays the same?

I confirmed that every iteration uses the same batch size as with the normal sampler,
and yes, I am moving the input to the GPU.

Would you please help me identify the problem? I am pasting the training code here:

for i, (input, target) in enumerate(train_loader):
    print('#images loaded: {}'.format(target.size()))

    if args.cuda and torch.cuda.is_available():
        # note: `async=True` was renamed to `non_blocking=True` in newer PyTorch versions
        target = target.cuda(non_blocking=True)
        input = input.cuda(non_blocking=True)

    # compute output
    output = model(input)
    loss = criterion(output, target)

    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

The issue has to do with frozen model weights.
For the same batch size:

  1. 10 GB of GPU memory is required for training all resnet101 layers.
  2. 2.3 GB of GPU memory is required for training only the layer4 and fc layers of resnet101, because gradients don't need to be tracked for layer1, layer2, and layer3.
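
For context, a minimal sketch of freezing everything except layer4 and fc in a resnet101 (assuming torchvision is available); gradients, and the intermediate buffers needed to compute them, are not kept for the frozen layers, which is where the memory saving comes from:

import torch
import torchvision

model = torchvision.models.resnet101()

# freeze all parameters first...
for param in model.parameters():
    param.requires_grad = False

# ...then unfreeze only layer4 and the final fc layer
for param in model.layer4.parameters():
    param.requires_grad = True
for param in model.fc.parameters():
    param.requires_grad = True

# only pass the trainable parameters to the optimizer
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.01, momentum=0.9)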

Hi, I have a very unbalanced dataset like yours; were you able to improve results with this weighted sampler?

Shouldn’t it be weight_per_class[i] = 1/float(count[i]) instead of weight_per_class[i] = N/float(count[i]) ?

Let's say I have a minibatch size of 16. How will the DataLoader use the WeightedRandomSampler? When called with iter, this sampler returns an iterator over len(weights) indices, but I only need 16 per batch, so how does that work out?

This might be relevant:

Hey @InnovArul, did you try to implement ImbalancedDatasetSampler? I could not find anything about it in the torch documentation.

@ptrblck

If I keep shuffle=True in the train loader, I get the error below:

raise ValueError('sampler option is mutually exclusive with '
ValueError: sampler option is mutually exclusive with shuffle

If you provide a sampler to the DataLoader, you cannot specify shuffle anymore, as the sampler is now responsible for creating the indices (and thus any shuffling).
There are a few samplers which enable shuffling, e.g. SubsetRandomSampler and WeightedRandomSampler.
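
In other words, a minimal sketch reusing the trainDataset, bs, and sampler names from the example above:

# either let the sampler pick the indices...
loader = torch.utils.data.DataLoader(trainDataset, batch_size=bs, sampler=sampler)

# ...or shuffle without a custom sampler, but not both:
loader = torch.utils.data.DataLoader(trainDataset, batch_size=bs, shuffle=True)

# passing both raises:
# ValueError: sampler option is mutually exclusive with shuffle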

@ptrblck I want to clarify one point regarding the WeightedRandomSampler.
While it is oversampling the minority class, it is also undersampling the majority class.
Let's say I have 100 images of classA and 900 images of classB.
Then the dataloader length will be 1000, and when we iterate in minibatches it will ensure an equal distribution, so approximately 500 images of classA and 500 images of classB will be used for training.
Can't we say it is oversampling the minority but undersampling the majority in the dataset?

You could assume this, if you use the described setup.

However, you could e.g. specify replacement=False, which will return num_samples unique indices.
The over/undersampling also depends on the specified weights, i.e. the WeightedRandomSampler does not automatically produce equal class distributions in each batch; you are free to specify whatever weights you need.
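
For example, a minimal sketch (with made-up class counts) of weighting towards a 2:1 instead of a 1:1 class ratio:

import torch
from torch.utils.data import WeightedRandomSampler

# hypothetical labels: 900 samples of class 0, 100 samples of class 1
targets = torch.cat([torch.zeros(900, dtype=torch.long),
                     torch.ones(100, dtype=torch.long)])

# desired sampling proportion per class (does not have to be uniform)
desired_dist = torch.tensor([2/3, 1/3])

class_counts = torch.bincount(targets).float()
samples_weight = (desired_dist / class_counts)[targets]

sampler = WeightedRandomSampler(samples_weight, num_samples=len(targets))
drawn = targets[list(sampler)]
print(torch.bincount(drawn))  # roughly [667, 333]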