How to obtain a batch randomly from DataLoader?

Hi guys, I’m wondering how to get a random mini-batch from a DataLoader like this:

    train_loader = DataLoader(cifar10.Cifar10(train=True, image_transform=cifar10.transformer), shuffle=False, batch_size=BATCH_SIZE, num_workers=BATCH_SIZE)

But it seems it can only yield batches one by one, in order:

    for (inputs, labels) in train_loader:

Is there any way to get the i-th batch directly, or to get a random batch? Thanks in advance.


You can create a torch.utils.data.Subset and feed that as input to the DataLoader.

Thank you, but could you please show me a detailed example?

The reason the DataLoader cannot fetch a batch from the middle is that, once its iterator is created, it prefetches data according to an order that is fixed internally (by its sampler).

So if you want the i-th batch or a random batch, you have to tell the loader up front which samples you want (so from the loader's point of view it is not really "i-th" or "random").
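To make the "internal order" concrete, here is a minimal sketch of how the batch indices are decided before any data is loaded. It uses a plain `range(10)` as a hypothetical stand-in for a dataset with 10 samples; the real loader does the equivalent internally when `shuffle=False`.

```python
from torch.utils.data import BatchSampler, SequentialSampler

# Hypothetical stand-in for a dataset with 10 samples (only its length matters here).
fake_dataset = range(10)

# SequentialSampler yields 0, 1, 2, ...; BatchSampler groups those indices into batches.
batches = list(BatchSampler(SequentialSampler(fake_dataset), batch_size=4, drop_last=False))

print(batches)     # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(batches[1])  # indices of the batch at position 1: [4, 5, 6, 7]
```

Since the i-th batch is just a known list of indices, you can build that list yourself and hand it to the loader.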

The method I proposed (Subset) is only available on master: https://pytorch.org/docs/master/data.html?highlight=subset#torch.utils.data.Subset

On 0.4 you could use SubsetRandomSampler instead: https://pytorch.org/docs/stable/data.html?highlight=subset#torch.utils.data.sampler.SubsetRandomSampler
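A rough sketch of the Subset approach, assuming a PyTorch version where `torch.utils.data.Subset` is available. The `TensorDataset` below is a made-up stand-in for the CIFAR-10 dataset (the label of each sample equals its index, so the result is easy to check), and `get_ith_batch` is a hypothetical helper, not part of PyTorch.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in dataset (assumption): 100 fake images, label == sample index.
dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.arange(100))

def get_ith_batch(dataset, i, batch_size):
    """Return the i-th batch, counted in dataset order (hypothetical helper)."""
    start = i * batch_size
    stop = min(start + batch_size, len(dataset))
    # Restrict the dataset to exactly the samples of batch i.
    subset = Subset(dataset, list(range(start, stop)))
    loader = DataLoader(subset, batch_size=batch_size, shuffle=False)
    # The subset holds a single batch, so the first item is the whole thing.
    return next(iter(loader))

inputs, labels = get_ith_batch(dataset, i=3, batch_size=8)
print(inputs.shape)     # torch.Size([8, 3, 32, 32])
print(labels.tolist())  # [24, 25, 26, 27, 28, 29, 30, 31]
```

Building a new loader per batch is fine for occasional lookups, but it would be slow if done for every training step.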

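So, something like:

    from torch.utils.data.sampler import SubsetRandomSampler

    # indices of the i-th batch; SubsetRandomSampler visits them in random order
    sampler = SubsetRandomSampler(list(range(i * BATCH_SIZE, (i + 1) * BATCH_SIZE)))
    train_loader = DataLoader(cifar10.Cifar10(train=True, image_transform=cifar10.transformer), sampler=sampler, shuffle=False, batch_size=BATCH_SIZE, num_workers=BATCH_SIZE)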
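For a random batch, the same idea works if you draw the batch position at random first. Below is a self-contained sketch under some assumptions: the `TensorDataset` is a made-up stand-in for the CIFAR-10 dataset (label equals sample index), and an incomplete last batch is simply ignored. Note that SubsetRandomSampler shuffles within the chosen indices, so the batch contains the right samples but not necessarily in dataset order.

```python
import random

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Stand-in dataset (assumption): 100 fake images, label == sample index.
dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.arange(100))
batch_size = 8
num_batches = len(dataset) // batch_size  # ignores an incomplete last batch

# Pick a random batch position, then list the sample indices it covers.
i = random.randrange(num_batches)
indices = list(range(i * batch_size, (i + 1) * batch_size))

# The sampler restricts the loader to those indices (visited in random order).
sampler = SubsetRandomSampler(indices)
loader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)

inputs, labels = next(iter(loader))
print(i, sorted(labels.tolist()))
```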

This is the simplest way I can think of.
