I am currently using torchtext to load an NLP dataset, and I want to oversample some classes (repeat samples of specific classes more often during training).
My code currently looks something like this:
    import torch
    from torchtext import data
    from torchtext.data import Field, TabularDataset

    batchsize = 32   # actual values defined elsewhere in my script
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenize = lambda x: x.split()

    TEXT = Field(sequential=True, tokenize=tokenize, lower=True, unk_token=None)
    LABEL = Field(sequential=False, use_vocab=True, unk_token=None)

    fields = [('id', None), ('text', TEXT), ('label', LABEL)]

    training, validation = TabularDataset.splits(
        path="./",
        train="training.csv",
        validation="validation.csv",
        format="csv",
        skip_header=True,
        fields=fields)

    TEXT.build_vocab(training, max_size=None)
    LABEL.build_vocab(training)

    train_iter, val_iter = data.BucketIterator.splits(
        (training, validation),
        batch_size=batchsize,
        device=device,
        sort_key=lambda x: len(x.text))
I have read here Balanced Sampling between classes with torchvision DataLoader that you can pass a sampler argument to PyTorch's DataLoader.
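If I understand that answer correctly, with a plain PyTorch DataLoader this would be done with something like a WeightedRandomSampler built from per-sample weights. A minimal sketch of my understanding (the tensors below are toy placeholders, not my actual data):

    import torch
    from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

    # Toy stand-in for a labelled dataset: 8 majority-class and 2 minority-class samples
    labels = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
    features = torch.randn(10, 5)
    dataset = TensorDataset(features, labels)

    # One weight per sample, inversely proportional to its class frequency,
    # so minority-class samples are drawn more often (i.e. oversampled)
    class_counts = torch.bincount(labels)
    sample_weights = 1.0 / class_counts[labels].float()

    sampler = WeightedRandomSampler(sample_weights,
                                    num_samples=len(dataset),
                                    replacement=True)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)

As far as I can tell, though, BucketIterator does not accept a sampler argument, so I don't see how to plug this into my torchtext pipeline.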
Is there something similar that would work for my use case?