I am currently using torchtext to load an NLP dataset, and I want to oversample some classes (repeat samples of specific classes more often during training).
My code currently looks something like this:
from torchtext.data import Field, TabularDataset, BucketIterator

tokenize = lambda x: x.split()

TEXT = Field(sequential=True, tokenize=tokenize, lower=True, unk_token=None)
LABEL = Field(sequential=False, use_vocab=True, unk_token=None)

fields = [('id', None), ('text', TEXT), ('label', LABEL)]

training, validation = TabularDataset.splits(
    path="./",
    train="training.csv", validation="validation.csv",
    format="csv",
    skip_header=True,
    fields=fields)

TEXT.build_vocab(training, max_size=None)
LABEL.build_vocab(training)

train_iter, val_iter = BucketIterator.splits(
    (training, validation),
    batch_size=batchsize,
    device=device,
    sort_key=lambda x: len(x.text))
I have read here (Balanced Sampling between classes with torchvision DataLoader) that you can pass a sampler argument to PyTorch's data.DataLoader.
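For context, this is roughly the pattern I mean, as a minimal self-contained sketch with plain PyTorch and WeightedRandomSampler (the toy labels, features, and the inverse-class-frequency weights below are just placeholders, not my real data):

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy data just to make the sketch runnable: 10 samples, imbalanced labels.
labels = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
features = torch.randn(len(labels), 5)
dataset = TensorDataset(features, torch.tensor(labels))

# Weight each sample inversely to its class frequency so that
# minority-class samples are drawn (and hence repeated) more often.
class_counts = torch.bincount(torch.tensor(labels))
sample_weights = 1.0 / class_counts[torch.tensor(labels)].float()

sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)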
Is there something similar that would work for my use case with torchtext?