How to select a specific batch from a dataset?

I have a dataset whose samples are grouped by year-month.

I’m able to create a custom Dataset that does a wonderful job of fetching data from anywhere in the dataset. However, I now need the DataLoader to return batches by year-month, so that each batch contains only the samples from a single year-month.

I’m not sure whether I should be writing a custom Sampler (or BatchSampler), or how such a sampler would restrict each batch to the samples of a single year-month as it pulls them out.
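A custom batch sampler is one reasonable fit here: a batch sampler yields whole lists of indices, so it can hand the DataLoader every index of one year-month at a time. Below is a minimal sketch, assuming the Dataset's index i corresponds to row i of the DataFrame. `GroupBatchSampler` is a hypothetical name, not a PyTorch class. It deliberately does not subclass torch's `Sampler`, because `DataLoader` documents `batch_sampler` as accepting a Sampler or any iterable of index batches.

```python
import random
from collections import defaultdict
from typing import Dict, Hashable, Iterator, List, Optional, Sequence


class GroupBatchSampler:
    """Yield one batch of dataset indices per group (e.g. per year-month).

    Hypothetical helper, not part of PyTorch. Assumes group_keys[i] is the
    group label of dataset index i.
    """

    def __init__(self, group_keys: Sequence[Hashable], shuffle: bool = False,
                 seed: Optional[int] = None) -> None:
        # Map each group key to the indices of its samples. Groups are kept
        # in order of first appearance; indices keep their original order.
        groups: Dict[Hashable, List[int]] = defaultdict(list)
        for idx, key in enumerate(group_keys):
            groups[key].append(idx)
        self.groups = dict(groups)
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[List[int]]:
        keys = list(self.groups)
        if self.shuffle:
            # Randomize the order in which groups are visited on each pass.
            self.rng.shuffle(keys)
        for key in keys:
            # Return a copy so callers cannot mutate the stored group.
            yield list(self.groups[key])

    def __len__(self) -> int:
        # Exactly one batch per group.
        return len(self.groups)
```

To use it you would pass something like `DataLoader(dataset, batch_sampler=GroupBatchSampler(df["year_month"].tolist()))`, where `year_month` is an assumed column name. Leave `batch_size`, `shuffle`, `sampler` and `drop_last` at their defaults, since PyTorch treats them as mutually exclusive with `batch_sampler`. Each batch is as large as its month, so a very busy month produces a very large batch.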

I start with a pandas DataFrame, which I pass in when I init my custom torch.utils.data.Dataset. I need to iterate over all the year-month groups, returning each one as its own batch of samples.
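Since the data already starts as a DataFrame, another option is to let pandas compute the groups and pass them straight to the DataLoader as a list of index lists. The sketch below assumes a `year_month` column plus illustrative `x`/`y` columns; `YearMonthDataset` and `year_month_batches` are hypothetical names standing in for your own Dataset and helper. Resetting the index matters: `GroupBy.indices` returns row positions, which only line up with Dataset indices if row i of the frame is sample i.

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


class YearMonthDataset(Dataset):
    """Hypothetical Dataset wrapping a DataFrame; row i is sample i."""

    def __init__(self, df: pd.DataFrame, feature_cols, target_col: str) -> None:
        # reset_index so that dataset index i always means row position i.
        self.df = df.reset_index(drop=True)
        self.features = torch.tensor(self.df[feature_cols].to_numpy(), dtype=torch.float32)
        self.targets = torch.tensor(self.df[target_col].to_numpy(), dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        return self.features[idx], self.targets[idx]


def year_month_batches(df: pd.DataFrame, key: str = "year_month"):
    # groupby(...).indices maps each group label to a NumPy array of row
    # positions; convert to int lists, one list (= one batch) per group,
    # ordered by sorted group label.
    positions = df.groupby(key).indices
    return [positions[k].tolist() for k in sorted(positions)]


# Illustrative data only.
df = pd.DataFrame({
    "year_month": ["2021-01", "2021-02", "2021-01", "2021-03", "2021-02"],
    "x": [1.0, 2.0, 3.0, 4.0, 5.0],
    "y": [0.0, 1.0, 0.0, 1.0, 1.0],
})
dataset = YearMonthDataset(df, feature_cols=["x"], target_col="y")
# A plain list of index lists is a valid batch_sampler.
loader = DataLoader(dataset, batch_sampler=year_month_batches(dataset.df))

for features, targets in loader:
    print(features.shape, targets.tolist())
```

Each iteration of `loader` then yields the stacked features and targets of exactly one year-month, visited in sorted label order, which is chronological for zero-padded `YYYY-MM` strings. A fixed list gives the same group order every epoch; a random order per epoch would need an iterable that reshuffles the groups on each pass.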