Loading many images in __getitem__ but returning only one batch at a time

I have a directory full of TFRecords. Each TFRecord corresponds to one image that has been subdivided into many tiles, together with their labels. For training, I would like to turn this into batches of, let’s say, 8 random tiles, which may come from different images. My idea is to load as many tiles as possible into RAM in the __getitem__ method (using dareblopy to decode the TFRecords), then to sample from them either in __getitem__ or in a custom collate function, and return (or possibly yield) one batch at a time.

Currently, I first create an array of image IDs in which each column holds all the files to read at one time (the code below indexes it as self.ids[:, idx]). I then pass this array into the custom Dataset class, so that the index in __getitem__ selects one such group of files. Here is sample code:

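import io

import dareblopy as db
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset


class TfrecordDataset(Dataset):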

    def __init__(self, ids, tfrecords_path):

        self.ids = ids
        self.tfrecords_path = tfrecords_path

    def __getitem__(self, idx):

        # Get list of image IDs.
        image_ids = self.ids[:,idx]

        # Initialize parser for TFRecords.
        features = {'tile_id': db.FixedLenFeature([], db.string),
                    'image': db.FixedLenFeature([], db.string),
                    'label_image': db.FixedLenFeature([], db.string)}
        parser = db.RecordParser(features)


        tile_list = []

        for image_id in image_ids:
            # Decode the tfrecords one by one
            tfr_db = db.RecordReader(self.tfrecords_path + "/" + str(image_id) + "/" + ".tfrecord")
            tfr_length = tfr_db.get_metadata()[2]

            # Gather all their tiles into tile_list
            for i in range(tfr_length):
                example = next(tfr_db)
                tile_id, tile, label = parser.parse_single_example(example)
                image_tile = Image.open(io.BytesIO(tile[0]))
                label_tile = Image.open(io.BytesIO(label[0]))

                image_tile = TF.to_tensor(image_tile)
                label_tile = TF.to_tensor(label_tile)

                tile_list.append((image_tile, label_tile, image_id))

        return tile_list

    def __len__(self):
        return self.ids.shape[1]

This does what I want in the sense that it successfully loads many images and combines their tiles into a single list, which can easily be shuffled. The problem is that __getitem__ then returns a huge list of thousands of tiles, and passing that into a DataLoader is awkward: the DataLoader treats each of these lists as a single sample. I would much prefer the Dataset to hand out reasonably sized batches, such as 8 tiles at a time. Is there an easy way to do this in PyTorch? If there is a method that accomplishes the same thing without preparing the array of IDs first, that’s of course fine too.
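To make the goal concrete, here is a minimal sketch of the behaviour I am after, written in plain Python so it runs without any TFRecords. It keeps a bounded pool of tiles in memory (a shuffle buffer), yields tiles one at a time in roughly random order, and groups them into batches of 8. The names shuffled_tiles, batches, and fake_load_tiles are hypothetical helpers made up for this illustration, and load_tiles stands in for the decoding loop in __getitem__ above; this is not an existing PyTorch or dareblopy API.

```python
import random
from typing import Callable, Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def shuffled_tiles(
    image_ids: Sequence,
    load_tiles: Callable[[object], Iterable[T]],
    buffer_size: int = 1000,
    seed: int = 0,
) -> Iterator[T]:
    """Yield tiles one at a time, approximately shuffled across images.

    load_tiles is a hypothetical callable that decodes one image's
    TFRecord and yields its tiles.
    """
    rng = random.Random(seed)
    ids = list(image_ids)
    rng.shuffle(ids)  # visit the images in random order
    buffer: List[T] = []
    for image_id in ids:
        for tile in load_tiles(image_id):
            if len(buffer) < buffer_size:
                buffer.append(tile)  # still filling the buffer
                continue
            # Buffer is full: emit a random buffered tile and
            # store the new tile in its slot.
            j = rng.randrange(buffer_size)
            yield buffer[j]
            buffer[j] = tile
    # Input exhausted: flush whatever is left, in random order.
    rng.shuffle(buffer)
    yield from buffer


def batches(items: Iterable[T], batch_size: int = 8) -> Iterator[List[T]]:
    """Group a stream of items into lists of at most batch_size items."""
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # last, possibly smaller, batch
        yield batch


def fake_load_tiles(image_id):
    """Stand-in for TFRecord decoding: 5 dummy tiles per image."""
    for k in range(5):
        yield (image_id, k)


all_batches = list(
    batches(shuffled_tiles(range(4), fake_load_tiles, buffer_size=6), batch_size=8)
)
```

In PyTorch terms, I assume the body of shuffled_tiles could live in the __iter__ method of a torch.utils.data.IterableDataset, with DataLoader(dataset, batch_size=8) doing the grouping and collation instead of batches. With num_workers > 0, each worker would presumably need its own subset of image IDs (for example via torch.utils.data.get_worker_info()), since otherwise every worker would yield the same tiles.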