DataLoader: Load a sequence of images

I have to load sequences of images using a custom PyTorch DataLoader, but none of the code and examples I could find online fit my case.
I have 22 folders (1 folder corresponds to 1 video). Each folder contains a CSV file describing the data as well as around 50 subfolders. Each subfolder contains 9 images, so my sequence length is 9 (a few contain 18 images, which is handled in the code below).
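For reference, the data on disk is laid out roughly like this (the folder and file naming is inferred from the loading code below, so the placeholder names are illustrative):

    <data_root>/
        <vid_id>/                       # 22 of these, one per video
            <vid_id>.csv                # CSV describing the cases of this video
            <seq_id>_<cam_id>/          # ~50 subfolders per video
                <case_id>_<cam_id>.png  # 9 images per sequence (a few have 18)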
I got the loading code below from a colleague, but he didn't wrap it in a custom Dataset/DataLoader that PyTorch can use, and I failed when I tried to create one myself (a rough sketch of what I have in mind is at the end of this post).
Below is the code showing how he loaded the data. I'd really appreciate it if someone could help me with this.

import logging
import os
import random
import threading

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
# imread presumably comes from an imaging library such as imageio or skimage; assuming imageio here:
from imageio import imread


# Memoizes Sequence.get so that each (seq_id, cam_id) pair maps to a single Sequence object
def cache(f):
    store = {}

    def _impl(lst, id, cam):
        try:
            return store[(id, cam)]
        except KeyError:
            v = f(lst, id, cam)
            store[(id, cam)] = v
            return v

    return _impl


# Sequence object, stores cases (images)
class Sequence(object):
    def __init__(self, id, cam):
        self.id = id
        self.cam = cam
        self.cases = []

    @staticmethod
    @cache
    def get(store, id, cam):
        seq = Sequence(id, cam)
        store.append(seq)
        return seq

# A case, contains a path and the label
class Case(object):
    def __init__(self, id, label, path):
        # Doubtful
        if label == 'truelabel*':
            label = 'truelabel'
        # Almost surely not a truelabel
        elif label == 'finding':
            label = 'normal'
        # Two of them consider normal
        elif label == 'double truelabel':
            label = 'truelabel'
        
        self.id = id
        self.label = label
        self.path = path
        self.cls = ('normal', 'truelabel').index(label)


# Reads the data
class Reader(object):
    IDS = set()
    VIDS = {}
    LOCK = threading.Lock()
    SEQ_SIZE = 9

    def __init__(self, path, discard=[], extensions=['png']):
        self.train_idxs = []
        self.train = []
        self.test_idxs = []
        self.test = []
        self.path = path
        
        # In case this is read in multiple threads, like tensorflow might do
        with Reader.LOCK:
            # Just read the data once, if it's already done, exit
            if Reader.IDS:
                return
            
            # Iterate path to find video IDs
            for dirname in os.listdir(path):
                # os.listdir returns bare names, so join with path before checking
                if os.path.isfile(os.path.join(path, dirname)):
                    continue

                if dirname in discard:
                    continue

                Reader.IDS.add(dirname)

            # Iterate each ID
            for vid_id in Reader.IDS:
                # Read data CSV
                vid_path = os.path.join(path, vid_id)
                df = pd.read_csv(os.path.join(vid_path, vid_id + '.csv'))

                # Iterate dataframe to parse all cases
                Reader.VIDS[vid_id] = []
                for _, row in df.iterrows():
                    case_id = row.id
                    cam_id = row.cam_id
                    seq_id = row.seq_id

                    # Label might be empty, in which case it is normal
                    label = row.label
                    label = label if isinstance(label, str) else 'normal'

                    # Get sequence and populate a case
                    seq = Sequence.get(Reader.VIDS[vid_id], seq_id, cam_id)
                    seq.cases.append(
                        Case(
                            case_id, 
                            label, 
                            os.path.join(path, vid_id, '{}_{}'.format(seq_id, cam_id), '{}_{}.png'.format(case_id, cam_id))
                        )
                    )

                    # Some sequences have more than 9 images.
                    # In this data they all end up with 18 (handled by the next 'if'), but log it just in case.
                    if len(seq.cases) == Reader.SEQ_SIZE + 1:
                        logging.debug('Sequence {}_{} has {} > {} cases'.format(seq_id, cam_id, len(seq.cases), Reader.SEQ_SIZE))

                    # If a sequence has double the amount of cases (18 vs 9), split it into two different sequences
                    if len(seq.cases) == Reader.SEQ_SIZE * 2:
                        logging.debug('Splitting sequence {}_{} into 2'.format(seq_id, cam_id))
                        seq2 = Sequence.get(Reader.VIDS[vid_id], seq_id + 1e6, cam_id)
                        seq2.cases = seq.cases[Reader.SEQ_SIZE:]
                        seq.cases = seq.cases[:Reader.SEQ_SIZE]

                
    # Splits all data into two sets
    def split(self, partition, n_splits=5, shuffle=False, random_state=None):
        assert partition < n_splits, 'partition must be lower than n_splits'
        
        # Split train/test ids
        ids = np.array(list(Reader.IDS))
        kf = KFold(n_splits=n_splits, shuffle=shuffle, random_state=random_state if shuffle else None)
        splits = kf.split(ids)
        for _ in range(partition):
            next(splits)
        
        # Assign sequences to train/test
        self.train_idxs, self.test_idxs = next(splits)
        self.train = [seq for id in self.train_idxs for seq in Reader.VIDS[ids[id]]]
        self.test = [seq for id in self.test_idxs for seq in Reader.VIDS[ids[id]]]


    # Iterates the given data (train/test) and yields the images and labels of each sequence
    def _yield(self, data, shuffle):
        if shuffle:
            random.shuffle(data)
        
        for seq in data:
            ims = np.empty((Reader.SEQ_SIZE, 256, 256, 3), dtype=np.uint8)
            labels = np.empty((Reader.SEQ_SIZE), dtype=int)

            for i, case in enumerate(seq.cases):
                ims[i] = imread(case.path)
                labels[i] = case.cls
            
            yield ims, labels
            
    # Returns a generator containing the train data
    def train_data(self, shuffle):
        return self._yield(self.train, shuffle)
    
    # Returns a generator containing the test data
    def test_data(self, shuffle):
        return self._yield(self.test, shuffle)

The batches get generated separately:

def get_batch(data, batch_size):
    batch_x = []
    batch_y = []

    for _ in range(batch_size):
        ims, labels = next(data)
        batch_x.append(ims[np.newaxis, ...])
        batch_y.append(labels[np.newaxis, ...])

    batch_x = np.concatenate(batch_x, axis=0)
    batch_y = np.concatenate(batch_y, axis=0)

    return batch_x, batch_y

And then I call it like this:

    train_data = reader.train_data(shuffle=shuffle)
    for i in range(len(reader.train) // batch_size):
        images, labels = get_batch(train_data, batch_size)
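
To make the question more concrete, below is a minimal sketch of the kind of Dataset wrapper I have in mind, reusing the Reader, Sequence and imread from the code above (the class name SequenceDataset and the tensor shapes are my own guesses, and I'm not sure this is the right approach):

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


# Wraps a list of Sequence objects (e.g. reader.train or reader.test after reader.split(...))
class SequenceDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        # Same per-sequence loading as Reader._yield above
        seq = self.sequences[idx]
        ims = np.empty((Reader.SEQ_SIZE, 256, 256, 3), dtype=np.uint8)
        labels = np.empty(Reader.SEQ_SIZE, dtype=np.int64)
        for i, case in enumerate(seq.cases):
            ims[i] = imread(case.path)
            labels[i] = case.cls
        # (seq_len, H, W, C) uint8 -> (seq_len, C, H, W) float in [0, 1]
        images = torch.from_numpy(ims).permute(0, 3, 1, 2).float() / 255.0
        return images, torch.from_numpy(labels)


# How I imagine using it instead of get_batch:
# reader = Reader(path)
# reader.split(partition=0)
# train_loader = DataLoader(SequenceDataset(reader.train), batch_size=batch_size, shuffle=True)
# for images, labels in train_loader:
#     ...  # images: (batch, seq_len, C, H, W), labels: (batch, seq_len)

If I understand correctly, DataLoader's default collation would then stack the per-sequence tensors into (batch, seq_len, C, H, W) batches, which is what get_batch does by hand above, but I'm not sure whether this is the correct way to set it up for my folder structure.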