Hi guys, I’m wondering how to randomly get a mini-batch from a DataLoader created like this:
train_loader = DataLoader(cifar10.Cifar10(train=True, image_transform=cifar10.transformer), shuffle=False, batch_size=BATCH_SIZE, num_workers=4)  # num_workers is the worker-process count, not the batch size
But it seems it can only output batches one by one:
for (inputs, labels) in train_loader:
Is there any way that lets me get the i-th batch directly, or get a random batch? Thanks in advance.
ruotianluo
(Ruotian(RT) Luo)
May 27, 2018, 3:39am
You can create a torch.utils.data.Subset and feed that as input to the DataLoader.
Thank you, but would you please show me a detail example?
ruotianluo
(Ruotian(RT) Luo)
May 27, 2018, 8:10am
The reason the DataLoader cannot fetch a batch from the middle is that, once it is created, it prefetches data according to its own internal ordering.
So if you want the i-th batch or a random batch, you have to tell the loader up front which indices you want (so it’s not truly the i-th batch, or random, in the loader’s original order).
The method I propose only works on master: https://pytorch.org/docs/master/data.html?highlight=subset#torch.utils.data.Subset .
On 0.4 you can instead use https://pytorch.org/docs/stable/data.html?highlight=subset#torch.utils.data.sampler.SubsetRandomSampler .
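For instance, here is a minimal sketch of the Subset approach, using a toy TensorDataset in place of the CIFAR-10 dataset from the question (the dataset and constants are stand-ins, not from the original post):

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Toy dataset standing in for cifar10.Cifar10: item k is (tensor([k.]), tensor(k))
dataset = TensorDataset(torch.arange(100).float().unsqueeze(1), torch.arange(100))

BATCH_SIZE = 10
i = 3  # index of the batch we want

# Restrict the dataset to exactly the indices of the i-th batch,
# then let the DataLoader serve that slice as a single batch.
subset = Subset(dataset, list(range(i * BATCH_SIZE, (i + 1) * BATCH_SIZE)))
loader = DataLoader(subset, batch_size=BATCH_SIZE)

inputs, labels = next(iter(loader))  # labels are 30..39, i.e. the 4th batch
```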
so, something like
sampler = SubsetRandomSampler(list(range(i * BATCH_SIZE, (i + 1) * BATCH_SIZE)))
train_loader = DataLoader(cifar10.Cifar10(train=True, image_transform=cifar10.transformer), sampler=sampler, batch_size=BATCH_SIZE, num_workers=4)  # sampler is mutually exclusive with shuffle, so shuffle is left at its default
This is the simplest way I can think of.
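And if you want a genuinely random batch rather than the i-th one, the same sampler works with randomly drawn indices. A sketch, again with a toy TensorDataset standing in for cifar10.Cifar10:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Toy dataset standing in for cifar10.Cifar10: item k is (tensor([k.]), tensor(k))
dataset = TensorDataset(torch.arange(100).float().unsqueeze(1), torch.arange(100))
BATCH_SIZE = 10

# Draw BATCH_SIZE distinct indices at random from the whole dataset,
# then build a one-batch loader over just those indices.
indices = torch.randperm(len(dataset))[:BATCH_SIZE].tolist()
loader = DataLoader(dataset, sampler=SubsetRandomSampler(indices), batch_size=BATCH_SIZE)

inputs, labels = next(iter(loader))  # one random batch of BATCH_SIZE samples
```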