Get two batches from dataloader

I am implementing semi-supervised learning for semantic segmentation. To mimic the semi-supervised setting, I need to split the training data in two: the first split is used with its ground-truth labels and the second split is used without any labels. My training alternates between mini-batches from the two splits.

Is there a way to fetch two mini-batches from the dataloader without modifying the dataloader class?

The standard training loop looks like this:

for batch_id, (img, mask, ohmask) in enumerate(trainloader):
    # Perform the training on batch index batch_id

Instead of fetching just the next batch, I want to fetch the next two batches.

Thanks in advance.

What about doubling the batch size and splitting each batch in two?
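A minimal sketch of that idea, assuming each batch is a tuple like (img, mask, ohmask) whose elements support len() and slicing along the first dimension (torch tensors do, and so do the plain lists used here as a stand-in). The names split_in_two, paired_batches and fake_loader are made up for illustration; in real code the loop would run over a trainloader created with twice the intended batch size.

```python
def split_in_two(batch):
    """Split every element of a batch along its first dimension into two halves.

    Works on anything that supports len() and slicing
    (torch tensors, NumPy arrays, plain lists).
    """
    first, second = [], []
    for item in batch:
        half = len(item) // 2
        first.append(item[:half])
        second.append(item[half:])
    return tuple(first), tuple(second)


def paired_batches(loader):
    """Yield (labeled_half, unlabeled_half) for each batch of a loader
    that was built with a doubled batch size."""
    for batch in loader:
        yield split_in_two(batch)


# Stand-in for a dataloader with batch_size=4 (twice the intended 2).
# Each batch is (img, mask, ohmask), represented here by plain lists.
fake_loader = [
    ([0, 1, 2, 3], [10, 11, 12, 13], [20, 21, 22, 23]),
    ([4, 5, 6], [14, 15, 16], [24, 25, 26]),  # shorter last batch
]

for batch_id, (labeled, unlabeled) in enumerate(paired_batches(fake_loader)):
    img, mask, ohmask = labeled   # supervised step: use the labels
    img_u, _, _ = unlabeled       # unsupervised step: ignore the labels
    # ... perform the two training steps for batch index batch_id
```

If the last batch has an odd size, the two halves differ by one sample; passing drop_last=True to the DataLoader avoids the short last batch altogether. Also note that with shuffling enabled, which samples end up in the "unlabeled" half changes from epoch to epoch, so this does not give a fixed labeled/unlabeled split.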

That would work. Thanks!