Iterate through loader until a condition is met

I need to build a data loader that keeps iterating through the underlying loader until a certain sample size is reached. For example, suppose my loader looks like this:

import torch
from torch.utils.data import DataLoader, Dataset


class CustomLoader(Dataset):
  def __init__(self) -> None:
    super().__init__()

    self.all_files = torch.randn((250))

  
  def __len__(self):
    return len(self.all_files)

  def __getitem__(self, idx):
    rnd = self.all_files[idx]
    x = torch.randn((21, 20, 4)) * rnd
    target = idx

    return x, target

if __name__ == "__main__":
  dataset = CustomLoader()
  loader = DataLoader(dataset, batch_size=8, shuffle=True)

  for b, (x, target) in enumerate(loader, 0):
    bat, peopl, tim, feat = x.size() # 8 * 21 * 20 * 4

    # Merge the batch and people dimensions
    x = x.view(bat * peopl, tim, feat) # 168 * 20 * 4

    # Some filtering condition; the number of rows kept varies per batch
    rand_el = torch.randint(0, 22, size=(1,)).item()
    x = x[:rand_el]

    print(x.size()) # e.g. 13 * 20 * 4

My x has shape 8 x 21 x 20 x 4, where 8 is the batch size. Once I have a batch, I merge the batch dimension with the 21 dimension, which gives 168 x 20 x 4. I then apply a condition that filters out some of the rows, so the final shape varies, for example 13 x 20 x 4. I need to keep fetching the next batch and appending its filtered rows until I have accumulated 32 samples.
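One way to do this, sketched here under some assumptions, is to leave the `DataLoader` untouched and wrap it in a plain Python generator that buffers filtered rows and emits a batch whenever 32 have accumulated; any rows beyond 32 are carried over into the next batch. The names `accumulate_batches` and `keep_mask` are hypothetical, `keep_mask` only mimics the random `x[:rand_el]` condition from the question, and a `TensorDataset` of random data stands in for `CustomLoader`. The labels are expanded with `repeat_interleave` so that each flattened row keeps the index of the sample it came from.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def keep_mask(x):
    """Hypothetical condition: keep a random-length prefix of the rows,
    mimicking `x[:rand_el]` in the question. Returns a boolean row mask."""
    keep = torch.randint(0, 22, size=(1,)).item()
    mask = torch.zeros(x.size(0), dtype=torch.bool)
    mask[:keep] = True
    return mask


def accumulate_batches(loader, target_size=32, mask_fn=keep_mask, drop_last=False):
    """Pull batches from `loader`, flatten and filter them, and yield
    (x, target) pairs with exactly `target_size` rows. The leftover rows
    are yielded as a smaller final batch unless `drop_last` is True."""
    xs, ts = [], []  # filtered chunks not yet emitted
    buffered = 0     # total number of rows currently held in `xs`
    for x, target in loader:
        bat, peopl, tim, feat = x.size()
        x = x.reshape(bat * peopl, tim, feat)
        target = target.repeat_interleave(peopl)  # one label per flattened row
        mask = mask_fn(x)
        x, target = x[mask], target[mask]
        if x.size(0) == 0:
            continue
        xs.append(x)
        ts.append(target)
        buffered += x.size(0)
        # Emit as many full batches as the buffer can fill; keep the remainder.
        while buffered >= target_size:
            x_all, t_all = torch.cat(xs), torch.cat(ts)
            yield x_all[:target_size], t_all[:target_size]
            xs, ts = [x_all[target_size:]], [t_all[target_size:]]
            buffered -= target_size
    if buffered > 0 and not drop_last:
        yield torch.cat(xs), torch.cat(ts)


if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(250, 21, 20, 4), torch.arange(250))
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    for x, target in accumulate_batches(loader, target_size=32):
        print(x.size(), target.size())
```

Because the filter runs after collation, the condition can still inspect the whole flattened batch, as in the loop from the question. Every yielded batch has 32 rows except possibly the last one, which is dropped when `drop_last=True`.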

What is the best way to approach this problem? Applying the filter inside the dataset class produces variable-length samples, which the default collate function cannot stack into a batch. As far as I know, a custom collate function cannot fetch the next batch either. Please advise.
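Alternatively, if the condition can be evaluated one sample at a time, a sketch based on `torch.utils.data.IterableDataset` sidesteps the collate problem: the dataset yields individual 20 x 4 rows, so every element has the same shape and the default collate function can stack any 32 of them. `FilteredRows` is a hypothetical name, and its random-prefix filter is only a placeholder for the real condition. Note that this filters per sample rather than over the flattened 168-row batch, so it only fits if the condition does not need the whole batch. Shuffling is done inside `__iter__`, because `DataLoader` raises an error when `shuffle=True` is combined with an `IterableDataset`.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class FilteredRows(IterableDataset):
    """Hypothetical wrapper that streams individual rows passing the filter,
    so a plain DataLoader(batch_size=32) can collate them."""

    def __init__(self, num_samples=250, people=21):
        super().__init__()
        self.scales = torch.randn(num_samples)  # plays the role of `all_files`
        self.people = people

    def __iter__(self):
        # Shuffle sample order here; DataLoader cannot shuffle an IterableDataset.
        for idx in torch.randperm(len(self.scales)).tolist():
            x = torch.randn(self.people, 20, 4) * self.scales[idx]  # one sample: 21 x 20 x 4
            keep = torch.randint(0, 22, size=(1,)).item()          # placeholder condition
            for row in x[:keep]:                                    # each row: 20 x 4
                yield row, idx


if __name__ == "__main__":
    loader = DataLoader(FilteredRows(), batch_size=32)
    for x, idx in loader:
        print(x.size(), idx.size())  # 32 rows per batch except possibly the last
```

With `num_workers > 0`, each worker would replay the full stream unless the work is split inside `__iter__` using `torch.utils.data.get_worker_info()`.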