Extract frame number and patch number for custom dataset

I have a custom dataset of 256x256px video frames. To limit my memory usage, I am using the filepath as a pointer. However, I would like to split these frames into 16x16 non-overlapping patches to feed into a Transformer, and as such am trying to compute the frame number as well as the patch number within the frame beforehand from the full frame * patch number.


class Dataset(torch.utils.data.Dataset):
    def __init__(self, directory='../data/*', get_subdirs=True, size=(16,16), max_ctx_length=4096):
        self.data = glob.glob(directory)
        if get_subdirs:
            data_temp = []
            for i in self.data:
                fata = glob.glob(i+"/*")
                fata.sort(key=lambda r: int(''.join(x for x in r if (x.isdigit()))))

        self.data = data_temp
        self.max_ctx_length = max_ctx_length
        self.size = size
    def __len__(self):
        return len(self.data)*self.size[0]-self.max_ctx_length-1
    def __getitem__(self, key):
        frame_start = int(np.ceil(key / self.size[0]))
        frame_end = int(np.ceil((key+self.max_ctx_length+1) / self.size[0]))

        frames = self.data[frame_start:frame_end]

        patch_start = int(np.mod(key, self.size[0]))
        patch_end = int(np.mod(key+self.max_ctx_length+1, self.size[0]))
        print(patch_start, patch_end)
        patches = []

        for im in range(len(frames)):
            image = (Tvio.read_image(frames[im], mode=Tvio.ImageReadMode.RGB).float() / 255).unsqueeze(0)
            ppf = F.unfold(image, self.size, stride=self.size).transpose(1,2).split(1, 1)
            if im == 0:
                if not(im==len(frames)-1):
                    ppf = ppf[patch_start:]
                    ppf = ppf[patch_start:patch_end]
            elif im == len(frames)-1:
                ppf = ppf[:patch_end]
        data_x = patches[:-2]
        data_y = patches[1:-1]
        return torch.cat(data_x, dim=1).squeeze(0), torch.cat(data_y, dim=1).squeeze(0)           

if __name__ == "__main__":
    dataset = Dataset()

Expected Output:
A tensor of shape [4096, 768].

Actual Output:
torch.Size([65279, 768])