Balanced Sampling between classes with torchvision DataLoader

Hi all,

I’m trying to find a way to do balanced sampling with ImageFolder and DataLoader on an imbalanced dataset. I suppose that I should build a new sampler.

I’m not sure if I’m missing something. Is there an already implemented way to do it?

Thanks

Code:

train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(traindir, transforms.Compose([
        transforms.Scale(600),
        transforms.RandomSizedCrop(512),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)

10 Likes

You should look at the How to Prevent Overfitting thread.

1 Like

For a dataset coming from a torchvision ImageFolder, the code that I finally used, based on the suggestion of @smth, is the following:

First, define the function:

def make_weights_for_balanced_classes(images, nclasses):
    # count the number of samples per class
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1
    # a class weight is inversely proportional to its frequency
    weight_per_class = [0.] * nclasses
    N = float(sum(count))
    for i in range(nclasses):
        weight_per_class[i] = N / float(count[i])
    # assign each sample the weight of its class
    weight = [0] * len(images)
    for idx, val in enumerate(images):
        weight[idx] = weight_per_class[val[1]]
    return weight

And after this, use it in the following way:

dataset_train = datasets.ImageFolder(traindir)

# For an unbalanced dataset we create a weighted sampler
weights = make_weights_for_balanced_classes(dataset_train.imgs, len(dataset_train.classes))
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

# Note: no shuffle here, since it is mutually exclusive with a sampler
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size,
                                           sampler=sampler, num_workers=args.workers, pin_memory=True)
61 Likes

Super helpful - thank you!

Could you please tell me how you iterate over train_loader? I mean the for loop.

1 Like

You can just loop over it:

for data, target in train_loader:
    data = ...
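A full training iteration would then look roughly like this (just a minimal sketch; model, criterion, and optimizer are assumed to be defined elsewhere):

for data, target in train_loader:
    # move the batch to the GPU, if one is used
    data, target = data.cuda(), target.cuda()

    # standard forward / backward / update step
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()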

Hi @ptrblck, thanks for your response! My code is like the following:

data = torch.FloatTensor(numClass, numDataPoints, data_dim)
target = torch.zeros(numClass, numDataPoints)
weight = torch.histc(target, bins=10, min = 0, max = 9)
weight = 1/(weight + 1e-3)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weight.type('torch.DoubleTensor'), bs)
data = data.view(-1,data_dim)
target = target.view(-1,1)
dataset_x = data.numpy()
dataset_y = target.numpy().astype(int)
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(dataset_x,
                                                    dataset_y,
                                                    test_size=0.33,
                                                    random_state=42,
                                                    stratify = dataset_y)
trainDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(y_train.astype(int)))
validDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(y_test.astype(int)))

trainLoader = torch.utils.data.DataLoader(dataset = trainDataset, batch_size=bs, num_workers=1, sampler = sampler)
testLoader = torch.utils.data.DataLoader(dataset = validDataset, batch_size=bs, shuffle=False, num_workers=1) 

As you mentioned above, I wrote this code:

for data, labels in train_loader:
         # print something

But something went wrong here. It gave me an error like this:

‘StopIteration’

Could you please tell me what is wrong with my code? Actually, I would like to get samples from my data loader 1000 times and then at each step train the network and do backward!

In your code the weights seem to be invalid, since you initialize them with an all-zero target.
The histc method will then just return a histogram with all samples falling into the first bin.
Maybe this is your error?
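Just to illustrate with dummy shapes similar to yours (the exact numbers are made up):

target = torch.zeros(10, 100)                        # all-zero targets, 1000 values in total
weight = torch.histc(target, bins=10, min=0, max=9)
print(weight)    # tensor([1000., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
weight = 1 / (weight + 1e-3)                         # class 0 gets a tiny weight, the empty classes get 1000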

Here is a sample snippet:

batch_size = 20
class_sample_count = [10, 1, 20, 3, 4] # dataset has 10 class-1 samples, 1 class-2 samples, etc.
weights = 1 / torch.Tensor(class_sample_count)
weights = weights.double()
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, batch_size)
trainloader = torch.utils.data.DataLoader(trainDataset, batch_size = batch_size, sampler = sampler)
5 Likes

Sorry, @ptrblck, please assume that the target is initialized with some random integer numbers between 0 and 9. In practice, they are not all zero. So, with this clarification, what is wrong with the code?

Ok, I found some other issues. I changed your code a bit. This should work:

import numpy as np
import torch
from torch.utils.data.sampler import WeightedRandomSampler

numDataPoints = 1000
data_dim = 5
bs = 100

# dummy data with a 90/10 class imbalance
data = torch.randn(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
                    np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

dataset_x = data.numpy()
dataset_y = target
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(dataset_x,
                                                    dataset_y,
                                                    test_size=0.33,
                                                    random_state=42,
                                                    stratify = dataset_y)
print('target train: {}/{}'.format(len(np.where(y_train == 0)[0]),
                                   len(np.where(y_train == 1)[0])))

class_sample_count = np.array([len(np.where(y_train==t)[0]) for t in np.unique(y_train)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in y_train])

samples_weight = torch.from_numpy(samples_weight)
sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight))

trainDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(y_train.astype(int)))
validDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(y_test.astype(int)))

trainLoader = torch.utils.data.DataLoader(dataset = trainDataset, batch_size=bs, num_workers=1, sampler = sampler)
testLoader = torch.utils.data.DataLoader(dataset = validDataset, batch_size=bs, shuffle=False, num_workers=1) 

for i, (data, target) in enumerate(trainLoader):
    print("batch index {}, 0/1: {}/{}".format(
        i, len(np.where(target.numpy() == 0)[0]), len(np.where(target.numpy() == 1)[0])))

It seems that weights should have the same length as your number of samples.
I created a dummy data set with a target imbalance of 9 to 1.
The for loop should loop through all your train samples with each batch containing approx. the same amount of
zeros and ones.

I hope it works for you!

11 Likes

Thanks for your great response! But one more question about your code: could you please explain how I can determine how many times the for loop executes? :slight_smile:

No problem :wink:
Sure, you can call len(trainLoader), which will return 7 for the example. In this particular case, the loader will yield 6 full batches of 100 samples and a final one with only 70 samples.
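In general (with the default drop_last=False) the length is just the number of drawn samples divided by the batch size, rounded up; for the example above that would be:

import math
len(trainLoader) == math.ceil(len(samples_weight) / bs)   # ceil(670 / 100) = 7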

2 Likes

Update: As of PyTorch 0.2, sampler is mutually exclusive with shuffle, so don’t pass shuffle=True when using a sampler.

4 Likes

By default the WeightedRandomSampler will use replacement=True, in which case the samples in a batch are not necessarily unique.

Also, shouldn’t the function call be sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), batch_size)?

You are right! Since replacement=True is the default option, there will probably be repetitions.

If you use the batch_size for the num_samples argument, your DataLoader will just return one batch of samples in the loop.
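To make the difference concrete, using the variables from the dummy example above (just a sketch):

# num_samples = bs: the sampler only draws bs indices, so the loader yields a single batch per epoch
sampler = WeightedRandomSampler(samples_weight, num_samples=bs)

# num_samples = len(samples_weight): roughly one epoch worth of batches, drawn with
# replacement, so individual samples may show up more than once
sampler = WeightedRandomSampler(samples_weight, num_samples=len(samples_weight))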

1 Like

I was able to use this weighted random data sampling using @ptrblck’s code.

Here is a snippet:

   print('Loading data...')
   train_dataset = datasets.ImageFolder(root=traindir,
                                    transform=transform_train)
   test_dataset = datasets.ImageFolder(root=valdir,
                                   transform=transform_test)
   print('Loading is Done!')

   num_classes = len(train_dataset.classes)

   #TODO: read class counts from the file (see the note after this snippet for computing them from the dataset)
   class_sample_counts = [10647, 5659, 31445, 40283,  800,  407, 1111, 22396,  610, 1288, 5708, 1538, 1848, 26015, 17639, 3859,  473, 2509,  579, 2636,  822, 1616, 1226,  949, 1725, 1306, 1758, 1704, 10637, 1091, 1036, 1292,  474,  569, 1682,  553,  506, 7571, 3598, 2280, 24291, 5725, 1319,  824, 5456, 1781, 4074, 2538, 5032,  503, 1623, 7251,  599, 9037, 12221, 2128, 2290,  459, 1549, 1739, 2297,  838,  469,  674, 1030,  994,  704,  672, 1690, 2442,  766,  578, 2032,  534,  552, 13934, 1138, 1372]

   # compute a weight for every sample in the dataset
   # the *_samples_weight lists hold the (unnormalized) probability of each example being drawn
   class_weights = 1./torch.Tensor(class_sample_counts)
   train_targets = [sample[1] for sample in train_dataset.imgs]
   train_samples_weight = [class_weights[class_id] for class_id in train_targets]
   test_targets = [sample[1] for sample in test_dataset.imgs]
   test_samples_weight = [class_weights[class_id] for class_id in test_targets]

   # now let's initialize the samplers
   train_sampler = torch.utils.data.sampler.WeightedRandomSampler(train_samples_weight, len(train_dataset))
   test_sampler = torch.utils.data.sampler.WeightedRandomSampler(test_samples_weight, len(test_dataset))

   train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,sampler=train_sampler, **kwargs)
   val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, sampler=test_sampler, **kwargs)
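Instead of hard-coding the class counts (the TODO in the snippet above), they could also be computed from the dataset itself, e.g. (a small sketch, assuming numpy is imported as np):

   train_targets = [label for _, label in train_dataset.imgs]
   class_sample_counts = np.bincount(train_targets, minlength=num_classes).tolist()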

10 Likes

@ptrblck I have a question:

With a plain sampler I needed 10GB of GPU memory for a batch size of 64.
However, with WeightedRandomSampler the same batch size uses only 2.3GB of GPU memory.

Does it have to do with repetitions? In fact, my data is heavily imbalanced.
Or should I be concerned?

Thanks!

The sampling shouldn’t influence the GPU memory.
I assume your data is on the CPU, you sample from it, and push it inside the training loop to the GPU.
Therefore, your training procedure does not know anything about the sampling procedure.
Your issue sounds quite strange.

Did you make sure the batch size stays the same?

I confirmed that every iteration uses the same batch size as in the normal sampler case, and yes, I am moving the input to the GPU.

Would you please help me identify the problem? I am pasting the training code here:

for i, (input, target) in enumerate(train_loader):
    print('#images loaded: {} '.format(target.size()))

    if args.cuda and torch.cuda.is_available():
        target = target.cuda(non_blocking=True)
        input = input.cuda(non_blocking=True)

    # compute output
    output = model(input)
    loss = criterion(output, target)

    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

The issue has to do with frozen model weights.
For the same batch size:

  1. 10GB of GPU memory is required when training all resnet101 layers.
  2. 2.3GB of GPU memory is required when training only layer4 and the fc layer of resnet101, because gradients don’t have to be tracked for layer1, layer2, and layer3.
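For reference, freezing the earlier layers could look roughly like this (just a sketch, assuming a torchvision resnet101; adapt it to your own setup):

import torch
import torchvision

model = torchvision.models.resnet101(pretrained=True)

# freeze everything first ...
for param in model.parameters():
    param.requires_grad = False

# ... then unfreeze only layer4 and the final classifier
for param in model.layer4.parameters():
    param.requires_grad = True
for param in model.fc.parameters():
    param.requires_grad = True

# pass only the trainable parameters to the optimizer
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()), lr=0.01, momentum=0.9)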