My question is about PyTorch dataloaders.
Suppose I have a data_loader in a function. When I call this function, I sample a number of batches from it. Then I leave the function and later call it again. How can I make sure that the batches sampled across calls contain unique values, i.e., that sampling is done without replacement?
I appreciate your guidance.
One approach would be to create the iterator manually and pass it around. This makes sure that each next() call continues the sampling process and does not restart the DataLoader iteration.
Here is a small example:
import torch
from torch.utils.data import TensorDataset, DataLoader

def fun(loader_iter):
    # Draw two batches from the shared iterator without restarting it
    for i in range(2):
        try:
            a = next(loader_iter)
            print(a)
        except StopIteration:
            print("loader_iter is empty, re-create it!")
dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset, batch_size=2)
loader_iter = iter(loader)
fun(loader_iter)
# [tensor([0, 1])]
# [tensor([2, 3])]
fun(loader_iter)
# [tensor([4, 5])]
# [tensor([6, 7])]
fun(loader_iter)
# [tensor([8, 9])]
# loader_iter is empty, re-create it!
fun(loader_iter)
# loader_iter is empty, re-create it!
# loader_iter is empty, re-create it!
loader_iter = iter(loader)
fun(loader_iter)
# [tensor([0, 1])]
# [tensor([2, 3])]
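If you don't want the caller to handle the exhausted iterator, one option is to re-create it inside the function and hand it back. Here is a minimal sketch; fun_with_reset and the reset-on-exhaustion behavior are my additions, not part of the snippet above:

def fun_with_reset(loader, loader_iter):
    # Draw two batches; if the iterator is exhausted, re-create it
    # from the DataLoader and continue with a fresh pass (hypothetical helper)
    for i in range(2):
        try:
            a = next(loader_iter)
        except StopIteration:
            loader_iter = iter(loader)  # starts a new pass over the dataset
            a = next(loader_iter)
        print(a)
    return loader_iter  # hand the (possibly new) iterator back to the caller

loader_iter = iter(loader)
loader_iter = fun_with_reset(loader, loader_iter)
# [tensor([0, 1])]
# [tensor([2, 3])]

Note that the without-replacement guarantee holds within a single pass: the DataLoader's default samplers (sequential, or a random permutation with shuffle=True) never repeat samples until the iterator is exhausted, so values can only repeat once a new iterator is created.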
Thank you very much! So creating the iterator once with iter() and drawing batches with next() is the solution.