How to make DataLoader sample certain rows of data more than once

Hello,

The data that I am trying to train my model on is represented as sort of a list of unique values. To be more precise, the data is represented as a table of two columns. The first column contains the datapoint values, and the second column contains an integer >= 1 which represents its frequency in the dataset.

Data represented like:

|Value   |Frequency   |
-----------------------
|AAA     |1           |
|BBB     |1           |
|CCC     |3           |
|DDD     |1           |
|...     |...         |
|ZZZ     |1           |

Now, when training my model, I would like to incorporate the fact that certain values are more frequent in the dataset than others. Therefore, I cannot just feed my model the values in the first column (which would just be a list of unique values), because then every value would be presented to the model once, regardless of whether its frequency was 1 or 100.

What I would like to do is have my DataLoader randomly sample from my dataset, but sample each row with an associated frequency of N exactly N times before the end of the epoch. It would be as if, instead of a list of unique values, I had one big list with certain values repeated the correct number of times, and that list was shuffled and I was sampling from it.

In other words, make it equivalent to sampling from:

|Value   |
----------
|AAA     |
|BBB     |
|CCC     |
|CCC     |
|CCC     |
|DDD     |
|...     |
|ZZZ     |
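For illustration, the "big list with repeated values" described above can be built explicitly: repeat each row index by its frequency, then shuffle the result once per epoch. This is a minimal sketch in plain PyTorch; the values and freq lists are made-up stand-ins for the two columns of the table.

```python
import torch

# Hypothetical stand-ins for the two columns of the table
values = ["AAA", "BBB", "CCC", "DDD"]
freq = torch.tensor([1, 1, 3, 1])

# Index i is repeated freq[i] times: tensor([0, 1, 2, 2, 2, 3])
expanded = torch.repeat_interleave(torch.arange(len(values)), freq)

# One random permutation of the expanded list = one "epoch" in random order
order = expanded[torch.randperm(len(expanded))]
epoch = [values[i] for i in order.tolist()]
print(epoch)
```

Every run prints a different order, but CCC always appears exactly 3 times.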

What would be the best way of going about doing this?

Thank you very much in advance,
Yuta

Hello Yuta,

Have a look at WeightedRandomSampler (torch.utils.data.WeightedRandomSampler).
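A minimal sketch of that idea, assuming the frequencies are used directly as sampling weights and that one "epoch" should contain sum(freq) draws. The toy dataset and frequency list are made up for illustration. Note that with replacement=True the counts match the frequencies only on average.

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# Toy dataset of 5 items and made-up frequencies (one per item)
ds = torch.arange(5)
freq = [1, 1, 3, 1, 2]

# Weights need not sum to 1; they are treated as relative probabilities.
sampler = WeightedRandomSampler(weights=freq, num_samples=sum(freq), replacement=True)
dl = DataLoader(ds, batch_size=4, sampler=sampler)

# 8 indices in total; item 2 is the most likely to be drawn,
# but it is not guaranteed to appear exactly 3 times.
drawn = torch.cat([batch for batch in dl])
print(drawn)
```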

Best,
Andrei

Dear Andrei,

Thank you for the reply!

This is very close to what I need, but I think in this case the sampler is probabilistic so there is no guarantee that rows with frequency N will be sampled N times, just that rows with frequency N are N times more likely to be sampled at any given time compared to rows with frequency 1.

Is there any way to have a dataloader sample rows with frequency N exactly N times over one epoch? I would like it to behave as if instead of there being one row with frequency=N, there are N rows of that same value in the dataset (as in the little tables I drew in my original post).

Thank you very much,
Yuta

Hi Yuta -

Sounds like we’re close. Just to clarify though, in your original post you said:

If one randomly “samples” from this there isn’t a guarantee that you’ll use CCC exactly 3 times. In expectation you will, but it’s probabilistic. When you say sampling from this, do you really mean iterating over this dataset but in a random order, such that CCC gets chosen exactly 3 times, but the order in which that happens (relative to other dataset items) is probabilistic?

If so, the most direct solution I can think of would be to subclass Dataset per this tutorial. You’d want __len__ to return the total number of occurrences (the sum of your 2nd column), and you’ll need a bit of logic in __getitem__(self, idx) to return what you’d like. I think the logic is just: take the cumulative sum of the 2nd column and return the first item from the 1st column whose cumulative sum is greater than idx. Finally, you’d access this via a DataLoader with shuffle=True, like this example in the tutorial: dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=0).
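A minimal sketch of that suggestion, assuming the two columns are available as plain Python lists (values and freq are illustrative names, and FrequencyDataset is a hypothetical class name). It uses bisect.bisect_right on the cumulative sums to find, for a given idx, the first row whose running total is greater than idx.

```python
import bisect
import itertools
from collections import Counter

from torch.utils.data import DataLoader, Dataset


class FrequencyDataset(Dataset):
    """Expose values[i] exactly freq[i] times without storing the repeats."""

    def __init__(self, values, freq):
        assert len(values) == len(freq), "one frequency per value"
        self.values = list(values)
        # Running totals, e.g. freq [1, 1, 3, 1] -> cumsum [1, 2, 5, 6]
        self.cumsum = list(itertools.accumulate(freq))

    def __len__(self):
        # Total number of (repeated) rows = sum of the frequency column
        return self.cumsum[-1] if self.cumsum else 0

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # Index of the first running total that is greater than idx
        row = bisect.bisect_right(self.cumsum, idx)
        return self.values[row]


ds = FrequencyDataset(["AAA", "BBB", "CCC", "DDD"], [1, 1, 3, 1])
dl = DataLoader(ds, batch_size=4, shuffle=True, num_workers=0)

# Strings are collated into plain lists, so each batch is a list of str
seen = Counter(v for batch in dl for v in batch)
print(seen)
```

With shuffle=True, each epoch visits every one of the sum(freq) positions exactly once, so CCC is returned exactly 3 times and only the order changes. Each lookup is O(log n) in the number of unique rows, and the repeated rows are never materialized.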

Let me know if I’ve understood your needs and if this puts you on the right track. If you get stuck I can try to implement the custom Dataset class on an artificial dataset.

Another solution would be to create your own Sampler. (This is a modified version of RandomSampler)

import torch
import random
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
from typing import Iterator, List, Sized

class MyRandomSampler(RandomSampler):
    r"""Sample elements randomly, yielding index i exactly frequency[i] times per epoch.
    Not everything from RandomSampler is implemented.

    Args:
        data_source (Dataset): dataset to sample from
        frequency (list): number of times each index should be sampled per epoch
    """
    data_source: Sized
    freq: List[int]

    def __init__(self, data_source: Sized, frequency: List[int]) -> None:
        assert len(data_source) == len(frequency), "Size of frequency list must match size of dataset."
        super().__init__(data_source)
        self.data_source = data_source
        self.freq = frequency
        self.idx: List[int] = []
        self.refill()

    def _remove(self, to_remove) -> None:
        # Remove the sampled indices just for this epoch (one occurrence per entry)
        for num in to_remove:
            if num in self.idx:
                self.idx.remove(num)

        self._num_samples = len(self.idx)

    def refill(self) -> None:
        # Refill the indices after iterating through the entire DataLoader:
        # index i appears freq[i] times (uses self.freq, not a global variable)
        self.idx = []
        for i, val in enumerate(self.freq):
            self.idx.extend([i] * val)
        self._num_samples = len(self.idx)

    def __iter__(self) -> Iterator[int]:
        chunk = 32  # indices drawn per step; independent of the DataLoader's batch_size
        for _ in range(self.num_samples // chunk):
            batch = random.sample(self.idx, chunk)
            self._remove(batch)
            yield from batch
        # Yield whatever is left (fewer than `chunk` indices) in random order
        yield from random.sample(self.idx, len(self.idx))
        self.refill()

I have passed the frequency list as a constructor argument; however, you could also take it from your dataset, or from wherever that list is stored.
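For comparison, the same per-epoch guarantee can be sketched without relying on RandomSampler internals: expand the indices once, then draw a single torch.randperm per epoch. This is an alternative approach, not the sampler above, and the class name ExactFrequencySampler is made up for illustration.

```python
from typing import Iterator, Sequence

import torch
from torch.utils.data import DataLoader, Sampler


class ExactFrequencySampler(Sampler[int]):
    """Yield index i exactly freq[i] times per epoch, in random order."""

    def __init__(self, freq: Sequence[int]) -> None:
        # Expanded index list, e.g. [1, 1, 3] -> tensor([0, 1, 2, 2, 2])
        self.expanded = torch.repeat_interleave(
            torch.arange(len(freq)), torch.as_tensor(freq, dtype=torch.long)
        )

    def __len__(self) -> int:
        return len(self.expanded)

    def __iter__(self) -> Iterator[int]:
        # A fresh random permutation each time the DataLoader starts an epoch
        perm = torch.randperm(len(self.expanded))
        yield from self.expanded[perm].tolist()


ds = torch.arange(5)
freq = [1, 1, 3, 1, 2]
dl = DataLoader(ds, batch_size=4, sampler=ExactFrequencySampler(freq))

for batch in dl:
    print(batch)
```

Shuffling the whole expanded list at once should give the same kind of uniformly random order as drawing chunks without replacement, while avoiding the repeated list.remove calls, each of which is linear in the list length.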

Here is a small test

# Fake Dataset
ds = torch.arange(5)
# Frequency that matches the length of the Dataset
freq = [1, 1, 3, 1, 2]

sampler = MyRandomSampler(ds, freq)

dl = DataLoader(ds, batch_size=4, sampler=sampler)

for batch in dl:
    print(batch)

As you can see in the output, index 2 (frequency 3) appears three times and index 4 (frequency 2) appears twice. These are just the indices that will be passed to __getitem__, so with your Dataset this should return each element as many times as its frequency specifies.

# Output:
tensor([2, 3, 0, 2])
tensor([2, 4, 1, 4])

Hope this helps!

Dear Andrei,

Yes, you’re right, I didn’t use the correct terminology in my original post. Thank you for clearing it up!

Your suggestion to subclass Dataset and modify its __len__ and __getitem__ was just what I needed! Thank you. I’ve chosen to solve my issue this way, and now have a working solution.

I’ve marked your post as the solution.

Dear Matias,

Thank you for your reply! Your suggestion is also a valid solution to my problem, although I chose to mark Andrei’s as the solution because it was more appropriate in my particular case. The suggestion and example are very much appreciated all the same!