DataLoader - using SubsetRandomSampler and WeightedRandomSampler at the same time

I have a dataset that contains both the training and validation set. I am aware that I can use SubsetRandomSampler to split the dataset into training and validation subsets. The dataset, however, has an unbalanced class ratio. How can I also use WeightedRandomSampler together with SubsetRandomSampler? Below is what I currently have using only the SubsetRandomSampler.

    # build the dataset
    dset = TrajectoryDataset(
        path,
        obs_len=args.obs_len,
        skip=args.skip)
    
    # get the class sample counts
    class_sample_count = [dset.positive_sample_count(), dset.negative_sample_count()]
  
    # split into train and validation indices
    validation_split = 0.2  # fraction of the data used for validation
    dataset_size = len(dset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    shuffle_dataset = True
    if shuffle_dataset:
        np.random.seed(1337)
        np.random.shuffle(indices)
    train_indices, valid_indices = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(valid_indices)

    train_loader = DataLoader(
        dset,
        batch_size=32,    # args.batch_size
        num_workers=2,    # args.loader_num_workers
        collate_fn=seq_collate,
        sampler=train_sampler)

    valid_loader = DataLoader(
        dset,
        batch_size=32,    # args.batch_size
        num_workers=2,    # args.loader_num_workers
        collate_fn=seq_collate,
        sampler=valid_sampler)

That’s an interesting use case!
Basically you could just use the subset indices to create your WeightedRandomSampler, i.e. calculate the class imbalance, weights etc.
Here is a small example:

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# Create dummy data with class imbalance 99 to 1
numDataPoints = 1000
data_dim = 5
bs = 100
data = torch.randn(numDataPoints, data_dim)
target = torch.cat((torch.zeros(int(numDataPoints * 0.99), dtype=torch.long),
                    torch.ones(int(numDataPoints * 0.01), dtype=torch.long)))

print('target train 0/1: {}/{}'.format(
    (target == 0).sum(), (target == 1).sum()))

# Create subset indices: the first 100 samples plus the last 5 (all positives)
subset_idx = torch.cat((torch.arange(100), torch.arange(-5, 0)))

# Compute samples weight (each sample should get its own weight)
class_sample_count = torch.tensor(
    [(target[subset_idx] == t).sum() for t in torch.unique(target, sorted=True)])
weight = 1. / class_sample_count.float()
samples_weight = torch.tensor([weight[t] for t in target[subset_idx]])

# Create sampler, dataset, loader
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
train_dataset = torch.utils.data.TensorDataset(
    data[subset_idx], target[subset_idx])
train_loader = DataLoader(
    train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

# Iterate DataLoader and check class balance for each batch
for i, (x, y) in enumerate(train_loader):
    print("batch index {}, 0/1: {}/{}".format(
        i, (y == 0).sum(), (y == 1).sum()))

As you can see, I’ve used the code to create a vanilla WeightedRandomSampler and just calculated some subset indices.
Using these indices, the class_sample_count, weight and samples_weight were calculated.
Let me know if that would work for you.


I am actually still unsure how to use it on my dataset. I now have the code written that gets the training and validation indices and weights.


# build the dataset
dset = TrajectoryDataset(
    path,
    obs_len=args.obs_len,
    skip=args.skip)

# indices
dset.train_indices # indices of the training set
dset.valid_indices # indices of the validation set
dset.train_weights # sample weights for the training set
dset.valid_weights # sample weights for the validation set

But I am not sure what to pass to the DataLoader. You created a TensorDataset with data[subset_idx], but I have my own dset because I need my own __getitem__ method, so I can't pass dset[dset.train_indices].

OK, I see.
How are you calculating the train and val indices?
Would it be possible to calculate them before creating the Dataset?
If so, you could create a train and a val dataset using the indices, i.e. only the samples at these indices will be loaded, for example like the sketch below.
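
A minimal sketch of that idea; the indices argument of TrajectoryDataset is hypothetical here, since I haven't seen your class:

import numpy as np

# compute the split before building any Dataset
dataset_size = 1000  # placeholder; use your real number of samples
indices = np.random.permutation(dataset_size)
split = int(np.floor(0.2 * dataset_size))
valid_indices, train_indices = indices[:split], indices[split:]

# each Dataset then only ever sees its own samples
# (indices is a hypothetical constructor argument you would have to add)
train_dset = TrajectoryDataset(path, obs_len=args.obs_len, skip=args.skip,
                               indices=train_indices)
valid_dset = TrajectoryDataset(path, obs_len=args.obs_len, skip=args.skip,
                               indices=valid_indices)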

If that’s not possible, could you share the code of your Dataset?


Thank you! I decided to create the training and validation indices before creating the Dataset. It is much simpler this way.


Hey @ptrblck,

I have a question about the training and the validation dataset. I have already separated the training, validation and test datasets, which are class-imbalanced, and I am using a WeightedRandomSampler to balance them. So my question is: is it necessary to apply the WeightedRandomSampler to the validation dataset as well? I am a little bit confused about it. Can you please help me with it?

Thanks,
Vishal

The weighted sampling should only be used for training, to balance the classes in each batch, which hopefully helps the training.

The validation and test accuracy are calculated on the complete datasets without any sampling (shuffling is also not needed here, since the order of the data won't change the metrics).
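
In code, that would look something like this (a sketch with placeholder dataset names):

from torch.utils.data import DataLoader

# weighted sampling only for the training set
train_loader = DataLoader(train_dataset, batch_size=64, sampler=train_sampler)

# plain loaders for evaluation; no sampler and no shuffling needed
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)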


Thank you @ptrblck I get it now. Thank you for the help!

How are the weights being calculated in the following line?

weight = 1. / class_sample_count.float()

As per the above logic, the data points with label 0 (which form 99% of the data) each have a lower probability of being chosen, right? Please correct me if I am wrong. I tried looking into the WeightedRandomSampler documentation but am still confused. Thanks.

Yes, the higher the count, the lower the weight and thus the probability of drawing this particular sample.
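
A small numerical illustration of the inverse-frequency weighting (the numbers are made up for this example):

import torch

class_sample_count = torch.tensor([990, 10])  # 99% class 0, 1% class 1
weight = 1. / class_sample_count.float()      # tensor([0.0010, 0.1000])
# each class-1 sample is 99x more likely to be drawn than a class-0 sample;
# since there are 99x more class-0 samples, batches are balanced in expectation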


Hi @ptrblck, thanks so much for the sample code. I followed the steps you mentioned, but I am still not getting well-distributed classes. The following is the class count; the dataset itself is well balanced, yet the batches are not:

Class count:
tensor([1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 945])

batch index 0, first four classes: 9/6/6/9
batch index 1, first four classes: 6/12/12/4
batch index 2, first four classes: 9/8/9/12
batch index 3, first four classes: 10/6/12/5
batch index 4, first four classes: 5/5/9/12

Are you seeing the same effect with a higher batch size?

I wouldn’t be able to increase the batch size beyond 64. (In the output above I had only printed the first four classes.) The following is with a batch size of 64:

0 [8, 9, 10, 13, 5, 9, 6, 4, 0, 0]
1 [10, 5, 7, 10, 8, 10, 9, 5, 0, 0]
2 [10, 6, 7, 10, 6, 7, 9, 9, 0, 0]
3 [7, 7, 6, 15, 7, 6, 8, 8, 0, 0]
4 [10, 7, 7, 7, 11, 5, 6, 11, 0, 0]
5 [9, 5, 6, 5, 7, 12, 8, 12, 0, 0]
6 [10, 5, 6, 10, 9, 7, 9, 8, 0, 0]
7 [7, 7, 13, 11, 5, 5, 6, 10, 0, 0]
8 [11, 5, 9, 12, 5, 5, 7, 10, 0, 0]
9 [9, 4, 10, 6, 10, 5, 10, 10, 0, 0]
10 [11, 11, 8, 7, 8, 5, 8, 6, 0, 0]
11 [10, 10, 8, 8, 8, 7, 8, 5, 0, 0]

Could you provide the code you are using for your WeightedRandomSampler please?
It seems that no samples from the last two classes are drawn, which might be the case if their weight is zero.
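
A quick sanity check you could run before creating the sampler (assuming your weight tensor is called train_samples_weight):

# samples with a weight of zero will never be drawn by the sampler
print((train_samples_weight == 0).sum())
print(train_samples_weight.min(), train_samples_weight.max())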

Here is the code:

class_sample_count = torch.tensor(
    [(torch.tensor(train_dataset.targets) == t).sum()
     for t in torch.unique(torch.tensor(train_dataset.targets), sorted=True)])
weight = torch.div(1., class_sample_count.float())
train_samples_weight = torch.tensor(
    [weight[t.item()] for t in torch.tensor(train_dataset.targets)[train_idx]])
train_sampler = torch.utils.data.WeightedRandomSampler(
    train_samples_weight, len(train_samples_weight))

This is the output of the above code:

Weights:
tensor([0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
        0.0011])
Class count:
tensor([1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 945])
train loader:
0 [11, 7, 7, 7, 10, 7, 8, 7, 0, 0]
1 [7, 10, 10, 5, 6, 7, 10, 9, 0, 0]
2 [6, 6, 6, 11, 13, 7, 7, 8, 0, 0]
3 [9, 8, 5, 9, 10, 7, 4, 12, 0, 0]
4 [9, 9, 6, 11, 9, 6, 8, 6, 0, 0]
5 [7, 10, 6, 5, 6, 9, 12, 9, 0, 0]

I found the issue, but I am not able to find a way around it. I am calculating the weights on the subset indices, but I am passing the whole dataset to the train_loader, so the sampler's weights don't line up with the dataset samples. I don't think there is a fast way to take the subset of the train_set, though.

How are you calculating the weights based on the subset?
Would wrapping your train_set into torch.utils.data.Subset work?

Hi @ptrblck, thanks for that advice. It seems to work slightly better, but the distribution is as follows:

0 [5, 6, 9, 3, 6, 7, 5, 9, 6, 8]
1 [7, 8, 4, 7, 5, 6, 7, 6, 9, 5]
2 [6, 9, 6, 11, 4, 3, 8, 4, 6, 7]
3 [7, 10, 14, 3, 9, 6, 7, 6, 2, 0]
4 [9, 6, 7, 4, 5, 10, 0, 7, 8, 8]
5 [9, 8, 5, 9, 6, 4, 6, 7, 6, 4]

Could you iterate your DataLoader once and sum the statistics of each drawn batch?
For large numbers of samples, it should come close to the provided weights.
If not, there might still be some issues with creating the Subset etc.
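
For example, something like this (assuming 10 classes):

import torch

class_counts = torch.zeros(10, dtype=torch.long)  # assuming 10 classes
for x, y in train_loader:
    class_counts += torch.bincount(y, minlength=10)
print(class_counts)  # should be roughly uniform over a full epoch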