Yes, exactly! At the end of the process, I get one batch plus the sampled value k.
I achieved this by delaying I/O and preprocessing until the collate function, but it’s not very flexible.
At the moment, I do:
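```python
import numpy as np

def collate_data(batch):
    # Sample one index k for the whole batch
    k = np.random.randint(0, 10)
    batch = [sample[k] for sample in batch]
    # I/O and preprocessing are deferred until here
    data = load_batch(batch)
    augmented_data = augment(data)
    return augmented_data, k
```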
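For readers who want to try the pattern without a dataset, here is a minimal self-contained sketch. It assumes each dataset item is a list of 10 candidate entries, and it uses hypothetical stand-ins for `load_batch` and `augment` (the post doesn't show them). The point it illustrates is that `k` is drawn once per collate call, so every sample in the batch uses the same index.

```python
import random

# Hypothetical stand-ins for the post's load_batch/augment helpers.
def load_batch(paths):
    # Pretend "loading" just tags each requested entry.
    return [f"loaded:{p}" for p in paths]

def augment(data):
    return [f"aug({d})" for d in data]

def collate_data(batch):
    # Draw k once per call, so every sample in the batch uses the same index.
    k = random.randrange(10)  # 0..9, same range as np.random.randint(0, 10)
    selected = [sample[k] for sample in batch]
    data = load_batch(selected)
    return augment(data), k

# A DataLoader hands collate_fn a list of dataset items; emulate a batch of 4
# items, each holding 10 candidate entries.
batch = [[f"item{i}_view{j}" for j in range(10)] for i in range(4)]
data, k = collate_data(batch)
print(k, data)
```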
The approach above seems a bit clunky, but it works. It would be nice if I could switch off augmentation using a flag. Maybe I should use a lambda function, something like:
```python
DataLoader(..., collate_fn=lambda batch: collate_data(batch, augmentFlag))
```
I’d be interested to know if there are other ways of achieving this.
Thanks!
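One common alternative to the lambda is `functools.partial` from the standard library, which binds the flag ahead of time. A practical difference: when a PyTorch `DataLoader` uses `num_workers > 0` with the "spawn" start method (the default on Windows and macOS), the `collate_fn` has to be pickled to reach the worker processes, and lambdas cannot be pickled, while a `partial` of a module-level function can. The sketch below shows this with stdlib only; `load_batch` and `augment` are hypothetical stand-ins, and `train_collate`/`val_collate` are illustrative names.

```python
import functools
import pickle
import random

def load_batch(paths):  # hypothetical stand-in
    return list(paths)

def augment(data):  # hypothetical stand-in
    return [f"aug({d})" for d in data]

def collate_data(batch, augmentFlag):
    k = random.randrange(10)  # one index per batch
    data = load_batch([sample[k] for sample in batch])
    if augmentFlag:
        data = augment(data)
    return data, k

# Bind the flag ahead of time instead of wrapping the call in a lambda.
train_collate = functools.partial(collate_data, augmentFlag=True)
val_collate = functools.partial(collate_data, augmentFlag=False)

# A partial of a module-level function survives a pickle round trip...
restored = pickle.loads(pickle.dumps(val_collate))

# ...whereas a lambda does not.
try:
    pickle.dumps(lambda batch: collate_data(batch, augmentFlag=False))
    lambda_picklable = True
except (pickle.PicklingError, AttributeError, TypeError):
    lambda_picklable = False

print(lambda_picklable)  # False
```

With PyTorch this would then be passed as, for example, `DataLoader(train_dataset, batch_size=32, collate_fn=train_collate, num_workers=4)`, where the dataset and batch size are placeholders.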
Update:
Lambda functions are how I solved this. It works; it's somewhat clunky, but ah well.
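```python
import numpy as np

def collate_data(batch, augmentFlag):
    # Sample one index k for the whole batch
    k = np.random.randint(0, 10)
    batch = [sample[k] for sample in batch]
    data = load_batch(batch)
    # Only augment when the flag is set (e.g. for training, not validation)
    if augmentFlag:
        data = augment(data)
    return data, k
```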
The DataLoaders are then created with lambda functions:
```python
val_dataset = DataLoader(..., collate_fn=lambda batch: collate_data(batch, augmentFlag=False))
train_dataset = DataLoader(..., collate_fn=lambda batch: collate_data(batch, augmentFlag=True))
```
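Another option that avoids lambdas is a small callable class: the configuration lives on the instance, and instances of a module-level class pickle cleanly, so they also work with worker processes. This is a sketch, not the post's code. `load_batch` and `augment` are hypothetical stand-ins, and the `num_choices` parameter is an assumed generalization of the hard-coded 10.

```python
import pickle
import random

def load_batch(paths):  # hypothetical stand-in for the post's loader
    return list(paths)

def augment(data):  # hypothetical stand-in for the post's augmentation
    return [f"aug({d})" for d in data]

class CollateData:
    """Collate callable whose options are stored on the instance."""

    def __init__(self, augmentFlag=True, num_choices=10):
        self.augmentFlag = augmentFlag
        self.num_choices = num_choices  # assumed: replaces the hard-coded 10

    def __call__(self, batch):
        # Draw one index per batch and apply it to every sample.
        k = random.randrange(self.num_choices)
        data = load_batch([sample[k] for sample in batch])
        if self.augmentFlag:
            data = augment(data)
        return data, k

train_collate = CollateData(augmentFlag=True)
val_collate = CollateData(augmentFlag=False)

# The instance, including its settings, survives a pickle round trip.
restored = pickle.loads(pickle.dumps(val_collate))
print(restored.augmentFlag)  # False
```

The instances can then be passed directly, e.g. `collate_fn=val_collate`, in place of the lambdas.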