Video frame pairs dataset

I am trying to build a dataset that produces random pairs of frames from the same video sequence. I have two alternative implementations each with its own drawback.

One class, deriving from Dataset, produces tensors of 15 frames, and its `__getitem__` function randomly samples 2 of them without replacement. This leads to a very small number of batches and a lot of "wasted" frames per epoch. However, one benefit is that each epoch "sees" different pairs of frames.

The other dataset class generates pairs of frames directly in the constructor by permutation. The benefit is that all the frames are used for each epoch. The drawback is that all epochs will see the same fixed frame pairs.

Is there a way to implement a dataset/dataloader that combines the benefits of both: one that creates new frame pairs for each epoch and uses all the frames in each epoch without waste?

Below is the code for the second approach:

class TwoFramesDataset(Dataset):
    """Dataset of random same-video frame pairs.

    Each item wraps two distinct frames drawn (without replacement) from
    the same video, plus that video's category label.  Pairs are fixed
    once, at construction time, by ``read_dataset`` — every epoch sees
    the same pairs.
    """

    def __init__(self, directory, dataset="train", transform=None):
        """
        Args:
            directory: folder containing the pickled ``train.p`` /
                ``dev.p`` / ``test.p`` split files.
            dataset: which split to load ("train", "dev"; anything else
                falls back to "test", matching the original behavior).
            transform: optional per-frame transform applied to the PIL
                image inside ``get_frame``.
        """
        self.transform = transform
        self.instances, self.labels = self.read_dataset(directory, dataset)
        # Keep data as torch tensors so zero_center can subtract in place.
        self.instances = torch.from_numpy(self.instances)
        self.labels = torch.from_numpy(self.labels)

    def __len__(self):
        # One entry per frame pair.
        return self.instances.shape[0]

    def __getitem__(self, idx):
        """Return {"instance": (2, H, W) tensor, "label": scalar tensor}."""
        # self.instances[idx] has shape (1, 2, H, W): a singleton block
        # wrapping one frame pair; the squeeze below drops the leading 1.
        transformed_instances = []
        for frames in self.instances[idx]:
            pair = torch.stack([self.get_frame(frames, 0),
                                self.get_frame(frames, 1)])
            transformed_instances.append(pair)
        return {
            "instance": torch.stack(transformed_instances).squeeze(),
            "label": self.labels[idx],
        }

    def get_frame(self, frames, i):
        """Extract frame *i* from a (2, H, W) pair block and transform it."""
        frame = frames[i, :, :].unsqueeze(0)
        # Pixel values are stored in [0, 255]; ToPILImage expects a float
        # tensor in [0, 1], hence the division.
        frame = torchvision.transforms.ToPILImage()(frame / 255)
        if self.transform is not None:
            frame = self.transform(frame)
        return frame.squeeze()

    def zero_center(self, mean):
        """Subtract *mean* from every pixel, in place."""
        self.instances -= float(mean)

    def read_dataset(self, directory, dataset="train", mean=None):
        """Load a split and build random frame pairs.

        Within each video the frames are shuffled with ``torch.randperm``
        and taken two at a time; an odd trailing frame is dropped.

        Args:
            directory: folder containing the pickled split files.
            dataset: split name ("train" / "dev" / fallback "test").
            mean: unused; kept for backward compatibility with existing
                callers.

        Returns:
            (instances, labels): ``instances`` is float32 with shape
            (N, 1, 2, 60, 80); ``labels`` is uint8.  Also stores the
            global pixel mean in ``self.mean`` for later ``zero_center``
            calls.
        """
        split = dataset if dataset in ("train", "dev") else "test"
        filepath = os.path.join(directory, split + ".p")
        # Use a context manager so the file handle is closed promptly
        # (the original pickle.load(open(...)) leaked it).
        with open(filepath, "rb") as f:
            videos = pickle.load(f)

        instances = []
        labels = []
        for video in videos:
            frames = video["frames"]
            perm = torch.randperm(len(frames)).tolist()
            # Pair up shuffled frames: (perm[0], perm[1]), (perm[2], perm[3]), ...
            for i in range(len(frames) // 2):
                pair = np.array([frames[perm[2 * i]], frames[perm[2 * i + 1]]])
                # Frames are assumed to be 60x80 (hard-coded upstream) —
                # TODO confirm against the pickle's actual frame shape.
                instances.append(pair.reshape((1, 2, 60, 80)))
                labels.append(CATEGORY_INDEX[video["category"]])

        instances = np.array(instances, dtype=np.float32)
        labels = np.array(labels, dtype=np.uint8)

        # Global mean over all pixels; consumers call zero_center(self.mean).
        self.mean = np.mean(instances)

        return instances, labels