Weighted sampling in torchtext

I am currently using torchtext to load an NLP dataset, and I want to oversample some classes (repeat samples of specific classes more often during training).
My code currently looks something like this:

from torchtext.data import Field, TabularDataset, BucketIterator

tokenize = lambda x: x.split()

TEXT = Field(sequential=True, tokenize=tokenize, lower=True, unk_token=None)
LABEL = Field(sequential=False, use_vocab=True, unk_token=None)

# The 'id' column is discarded by mapping it to None
fields = [('id', None), ('text', TEXT), ('label', LABEL)]

training, validation = TabularDataset.splits(
                        path="./",
                        train="training.csv", validation="validation.csv",
                        format="csv",
                        skip_header=True,
                        fields=fields)

TEXT.build_vocab(training, max_size=None)
LABEL.build_vocab(training)

train_iter, val_iter = BucketIterator.splits(
                    (training, validation),
                    batch_size=batchsize,
                    device=device,
                    sort_key=lambda x: len(x.text))

I have read here (Balanced Sampling between classes with torchvision DataLoader) that you can pass a sampler argument to PyTorch's torch.utils.data.DataLoader.
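A minimal sketch of that idea, treating the list of torchtext Examples as a plain map-style dataset (the label counting here is my own code, not part of any torchtext API):

from collections import Counter
from torch.utils.data import DataLoader, WeightedRandomSampler

# Count how often each class occurs in the training set
label_counts = Counter(ex.label for ex in training.examples)

# Weight each example inversely to its class frequency so that rare
# classes are drawn more often when sampling with replacement
weights = [1.0 / label_counts[ex.label] for ex in training.examples]
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

# DataLoader accepts the sampler, but it yields raw Examples, so a real
# collate_fn that numericalizes text and labels would still be needed
loader = DataLoader(training.examples, sampler=sampler, batch_size=batchsize, collate_fn=list)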

Is there something similar that would work for my use case?
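The closest workaround I can think of is to oversample by repeating the Example objects of under-represented classes before building the iterator, roughly like this (repeat_factor is a placeholder mapping I made up):

# Repeat Examples of minority classes; sharing references is fine
# since Examples are not mutated during iteration
repeat_factor = {'rare_class': 3}  # hypothetical: class label -> total copies
extra = [ex for ex in training.examples
         for _ in range(repeat_factor.get(ex.label, 1) - 1)]
training.examples.extend(extra)  # rebuild the iterator afterwards

But a proper sampler-based solution would be cleaner.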


I am also stuck at the same place. Did you ever find a fix?