Let’s say I have a PyTorch dataloader = DataLoader(...) object. I don’t want to iterate through the whole dataset with for data, label in dataloader: every time the function is called, so currently I use:
dataloader = DataLoader(...)
iter_dataloader = iter(dataloader)
batch = next(iter_dataloader)  # Set the first batch

def train_batch():
    global batch, iter_dataloader  # both are reassigned below
    data, label = batch
    prediction = model(data)
    # Do fancy things here
    try:
        batch = next(iter_dataloader)  # Load the next batch
    except StopIteration:
        # The iterator reached the end: re-create it and start over
        iter_dataloader = iter(dataloader)
        batch = next(iter_dataloader)

for _ in range(N):
    train_batch()  # This function is called multiple times
For each call of train_batch(), I get a batch from the dataset, train the model, and load the next batch. If there is no batch left, I re-create the iterator over the DataLoader.
Now to my questions:
- Is there a way to write this more cleanly? That is, I don’t want to use the iter and next calls explicitly: every time I ask for data, it should automatically sample a batch and automatically reset when the iterator reaches the end. I have heard about Sampler, but I have not used it. (See the first sketch after this list.)
- Extension of the above: instead of one batch, can I get K batches at a time, or a 1/K fraction of the dataset I am using? (See the second sketch.)
- I want three sampling methods: (1) sampling batches in sequence (no shuffle), (2) sampling randomly (with and without replacement, i.e. shuffle), and (3) sampling such that the labels are equally represented. Is there a way to do this? (See the third sketch.)
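
A minimal sketch of the auto-resetting pattern, assuming the dataloader, model, and N defined above: wrapping the DataLoader in a generator builds a fresh iterator whenever the previous one is exhausted, so next() never raises StopIteration, and the loader re-shuffles on each pass if shuffle=True.

def infinite_batches(loader):
    # Yield batches forever: each new "for" pass over the DataLoader
    # creates a fresh iterator instead of raising StopIteration.
    while True:
        for batch in loader:
            yield batch

batches = infinite_batches(dataloader)

def train_batch():
    data, label = next(batches)  # auto-samples, auto-resets at the epoch boundary
    prediction = model(data)
    # Do fancy things here

for _ in range(N):
    train_batch()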
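
For K batches at a time, one option is itertools.islice over the same infinite stream. K is a hypothetical constant here, and len(dataloader) (the number of batches per epoch, available for map-style datasets) gives a rough 1/K slice:

from itertools import islice

def next_k_batches(batch_stream, k):
    # Pull k consecutive batches from the infinite stream above
    return list(islice(batch_stream, k))

chunk = next_k_batches(batches, K)                                # exactly K batches
fraction = next_k_batches(batches, max(1, len(dataloader) // K))  # about 1/K of one epoch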
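
For the three sampling methods, the stock samplers in torch.utils.data cover all three cases. This is a sketch assuming a map-style dataset yielding (data, label) pairs with integer labels and an illustrative batch size of 32; note that DataLoader does not allow passing sampler together with shuffle=True.

import torch
from torch.utils.data import DataLoader, SequentialSampler, RandomSampler, WeightedRandomSampler

# (1) In sequence, no shuffle (equivalent to shuffle=False):
seq_loader = DataLoader(dataset, batch_size=32, sampler=SequentialSampler(dataset))

# (2) Random, without replacement (equivalent to shuffle=True) or with replacement:
rand_loader = DataLoader(dataset, batch_size=32,
                         sampler=RandomSampler(dataset, replacement=False))
rand_repl_loader = DataLoader(dataset, batch_size=32,
                              sampler=RandomSampler(dataset, replacement=True))

# (3) Label-balanced: weight each sample by the inverse frequency of its class,
# so every class is drawn equally often in expectation.
labels = torch.tensor([label for _, label in dataset])  # assumes integer labels
class_counts = torch.bincount(labels)
weights = 1.0 / class_counts[labels].float()
balanced_loader = DataLoader(
    dataset, batch_size=32,
    sampler=WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True))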