Sample two batches at a time with DataLoader

I have one dataloader. I would like to get two batches of data from it at once. The only mechanism I could find online only allows me to get one batch at a time:
`for batch_idx, (data, target) in enumerate(train_loader)`

Do you have any suggestions for getting around this?
Thank you!

One way to do it is to double the batch size and split the batched output into two smaller batches.

torch.split() is the function to do that, right? Thanks!

Yes. It returns a tuple containing the two batches.
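A minimal sketch of that idea: the tensor shapes below are made up for illustration, but the pattern is just splitting a "double batch" along the batch dimension with torch.split.

```python
import torch

# pretend this is one batch from a DataLoader configured with
# batch_size = 2 * 4, i.e. a "double batch" of 8 samples
data = torch.randn(8, 3)
target = torch.randint(0, 10, (8,))

# split along dim 0 (the batch dimension) into two batches of 4
data_1, data_2 = torch.split(data, 4, dim=0)
target_1, target_2 = torch.split(target, 4, dim=0)
```

Note that `torch.split(tensor, 4)` takes the *chunk size*, not the number of chunks; `torch.chunk(tensor, 2)` is the variant that takes the number of chunks instead.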

As an alternative to loading a batch of twice the size and splitting it, you could wrap the DataLoader in an iterator and use the next() function (or the .next() method in Python 2.7). E.g.,

from torch.utils.data import DataLoader

...
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4)

train_loader_iter = iter(train_loader)
for batch_idx, (features, targets) in enumerate(train_loader_iter):

    # do sth with the first batch (features & targets)

    # fetch the second batch from the same iterator; the for loop
    # will then resume with the batch after it. Guard against an
    # odd number of batches, where no second batch is left.
    try:
        features_2, targets_2 = next(train_loader_iter)
    except StopIteration:
        break
