DataLoader: iterate through a subset of the data instead of the whole dataset every time it is called

Let's say I have a torch DataLoader object, dataloader = DataLoader(...). I don't want to iterate through the whole dataset every time I run for data, label in dataloader: inside a function, so currently I use:

from torch.utils.data import DataLoader

dataloader = DataLoader(...)
iter_dataloader = iter(dataloader)
batch = next(iter_dataloader)  # Set the first batch

def train_batch():
   global batch, iter_dataloader  # both names are rebound inside this function
   data, label = batch
   prediction = model(data)
   # Do fancy things here
   try:
      batch = next(iter_dataloader)  # Load the next batch
   except StopIteration:
      iter_dataloader = iter(dataloader)  # the iterator is exhausted, so reset the DataLoader
      batch = next(iter_dataloader)

for _ in range(N):
   train_batch()  # This function is called multiple times

On each call to train_batch(), I take the current batch, train the model on it, and load the next batch. If no batches are left, I reset the iterator over the DataLoader.

Now to my questions:

  1. Is there a way to make this code cleaner? That is, I would rather not manage iter and next by hand: every time I ask for data, a batch should be sampled automatically, and the loader should reset automatically when it reaches the end. I have heard about Sampler, but I have not used it. (A sketch of what I mean follows this list.)
  2. As an extension of the above: instead of a single batch, can I draw K batches at a time, or 1/K of the dataset, per call? (See the second sketch below.)
  3. I want three sampling methods: (1) sampling batches in sequence (no shuffle), (2) sampling batches randomly, both with and without replacement (shuffle), and (3) sampling such that the labels are balanced, i.e. each class appears equally often. Is there a way to do this? (See the third sketch below.)
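
For question 1, here is a rough sketch of what I mean by "auto reset": an infinite generator that re-creates the iterator whenever the DataLoader is exhausted (infinite_loader is just a name I made up, not a torch API):

def infinite_loader(dataloader):
   # Loop forever; each pass over the DataLoader is a fresh epoch,
   # so shuffling (if enabled) is re-applied at every reset.
   while True:
      for batch in dataloader:
         yield batch

batches = infinite_loader(dataloader)

def train_batch():
   data, label = next(batches)  # never raises StopIteration
   prediction = model(data)
   # Do fancy things here

Because the generator re-iterates the DataLoader rather than caching its batches (as itertools.cycle would), a shuffled loader still reshuffles at every wrap-around.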
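
For question 2, these are the two variants I can think of, building on the sketch above (K is arbitrary, and subset / small_loader are my own placeholder names):

from itertools import islice
from torch.utils.data import DataLoader, Subset

K = 4

# Variant A: pull K consecutive batches per call from the infinite generator.
def train_k_batches():
   for data, label in islice(batches, K):
      prediction = model(data)
      # Do fancy things here

# Variant B: build a DataLoader over only 1/K of the dataset via Subset.
subset = Subset(dataset, range(len(dataset) // K))  # first 1/K of the indices
small_loader = DataLoader(subset, batch_size=32, shuffle=True)

Is one of these the intended way, or is there a built-in mechanism for it?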
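
For question 3, based on my reading of the torch.utils.data docs, the built-in samplers seem to cover all three cases. The label-balancing weights are my own guess, and I assume dataset yields (data, label) pairs with integer class labels:

import torch
from torch.utils.data import (DataLoader, SequentialSampler, RandomSampler,
                              WeightedRandomSampler)

# (1) in sequence, no shuffle (equivalent to the default shuffle=False)
loader_seq = DataLoader(dataset, batch_size=32, sampler=SequentialSampler(dataset))

# (2) random, with or without replacement
loader_rand = DataLoader(dataset, batch_size=32,
                         sampler=RandomSampler(dataset, replacement=True))  # or False

# (3) label-balanced: weight every sample by the inverse frequency of its class,
# so each class is drawn equally often in expectation.
labels = torch.tensor([label for _, label in dataset])
class_counts = torch.bincount(labels)
weights = 1.0 / class_counts[labels].float()
loader_bal = DataLoader(dataset, batch_size=32,
                        sampler=WeightedRandomSampler(weights, num_samples=len(dataset)))

Does this look like the right approach, or is there something cleaner?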