Change Dataset values on each epoch with persistent workers

Hi!

I am working on a video-loading pipeline. Instead of loading all possible sequences from each video in the dataset, I want to randomly sample N sequences from each video (so all videos are seen in every epoch, but the epoch is a bit shorter). To make use of the full dataset, I would like these N sequences to (potentially) differ from epoch to epoch. My current approach is to keep a “mapping” attribute that holds a randomly selected set of indices to be used in the current epoch. What I need is for some function to be called at the end of each epoch so that this mapping is recomputed (and for all workers to share the same mapping).
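For reference, the per-video selection itself can be sketched as a small function. This is a minimal sketch that assumes the number of sequences in each video is known up front and that sequences are numbered globally, one video after another; sample_epoch_indices and seqs_per_video are hypothetical names, not part of any library.

```python
import torch


def sample_epoch_indices(seqs_per_video, n, generator=None):
    """Pick up to n random sequence indices from every video (hypothetical helper).

    seqs_per_video[v] is the number of sequences video v contains. Sequences are
    numbered globally: video 0 owns 0..seqs_per_video[0]-1, video 1 owns the
    next block, and so on. Returns a sorted list of global indices.
    """
    indices = []
    offset = 0
    for count in seqs_per_video:
        k = min(n, count)  # short videos contribute all of their sequences
        local = torch.randperm(count, generator=generator)[:k]
        indices.extend((local + offset).tolist())
        offset += count
    return sorted(indices)


if __name__ == "__main__":
    g = torch.Generator().manual_seed(0)
    print(sample_epoch_indices([5, 3, 8], n=2, generator=g))
```

Calling it once per epoch yields a new selection each time while still covering every video.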

I wrote the following simplified version of what I want to achieve:

```python
import torch


class Dummy(torch.utils.data.Dataset):
    def __init__(self, num_sequences, dim):
        # 2 * num_sequences candidate items; item i is a tensor filled with i
        self.data = torch.stack([torch.ones(dim) * x for x in range(num_sequences * 2)])
        self.num_sequences = num_sequences
        self.mapping = list()
        self.make_mapping()

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, idx):
        # idx only ranges over the selection made for the current epoch
        return self.data[self.mapping[idx]]

    def make_mapping(self):
        # Randomly select num_sequences of the candidate items for this epoch
        self.mapping = sorted(torch.randperm(len(self.data))[:self.num_sequences].tolist())


if __name__ == "__main__":
    dset = Dummy(10, 1)
    dload = torch.utils.data.DataLoader(dset, batch_size=5, num_workers=2, persistent_workers=True)
    for e in range(5):
        print(f"Epoch {e}")
        for b in dload:
            print(b)
        # With persistent_workers=True this only updates the main process's copy
        dload.dataset.make_mapping()
```

The problem arises with persistent_workers=True: each worker holds its own copy of the Dataset, so calling make_mapping() in the main process only updates the main process's copy and has no effect on what the workers return.
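One workaround that keeps the per-epoch mapping inside the Dataset is to store it in a fixed-size tensor placed in shared memory with Tensor.share_memory_() before the workers start, and then overwrite it in place from the main process. This is a sketch under two assumptions: the mapping length never changes (the tensor is preallocated), and make_mapping() is only called between epochs, while the workers are idle. SharedMappingDummy is a hypothetical variant of Dummy.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SharedMappingDummy(Dataset):
    """Hypothetical variant of Dummy whose mapping lives in shared memory."""

    def __init__(self, num_sequences, dim):
        self.data = torch.stack([torch.ones(dim) * x for x in range(num_sequences * 2)])
        self.num_sequences = num_sequences
        # Fixed-size buffer moved to shared memory *before* the DataLoader
        # starts its workers, so every worker refers to the same memory.
        self.mapping = torch.empty(num_sequences, dtype=torch.long).share_memory_()
        self.make_mapping()

    def __len__(self):
        return self.num_sequences

    def __getitem__(self, idx):
        return self.data[self.mapping[idx]]

    def make_mapping(self):
        # Only meant to be called from the main process, between epochs.
        new = torch.randperm(len(self.data))[: self.num_sequences]
        # Overwrite in place: reassigning self.mapping would create a new,
        # unshared tensor that the workers never see.
        self.mapping.copy_(torch.sort(new).values)


if __name__ == "__main__":
    dset = SharedMappingDummy(10, 1)
    dload = DataLoader(dset, batch_size=5, num_workers=2, persistent_workers=True)
    for e in range(5):
        # With dim=1, each returned value equals the selected index.
        seen = torch.cat(list(dload)).flatten().long()
        print(f"Epoch {e}: {seen.tolist()}")
        # Sanity check: the workers used the mapping last written by the main process.
        assert torch.equal(seen, dset.mapping)
        dset.make_mapping()
```

The in-place copy_() is the important detail: the workers only ever share the buffer that existed when they were started.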

I have been searching, and while I have found people with similar problems (generally related to cropping or inputs with different resolutions), I have not found a way to adapt their solutions (workarounds for actually changing Dataset attributes) to my case. I also found this, but it seems like too much hassle just to shuffle the data.

I would think there should be a way to handle some end-of-epoch event so that make_mapping is called. I tried something along these lines by counting how many times __getitem__ is called, so that each worker calls make_mapping once its counter reaches num_sequences // num_workers. This more or less works, but since each worker makes its own shuffle, the same sequence might be seen multiple times in a single epoch.
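In the counter approach, the duplicates come from each worker drawing its own random permutation. One way around that is to make the mapping a deterministic function of the epoch number, so every process that reaches the same epoch computes the same subset without any communication. A minimal sketch; make_mapping's signature and base_seed are hypothetical. It still assumes each worker knows the current epoch correctly, which a per-worker __getitem__ counter only guarantees when items split evenly across workers.

```python
import torch


def make_mapping(num_items, num_sequences, epoch, base_seed=1234):
    """Return the same sorted index subset for a given epoch in every process.

    A fresh generator is seeded from base_seed + epoch, so any worker calling
    this with the same arguments gets an identical result.
    """
    g = torch.Generator().manual_seed(base_seed + epoch)
    perm = torch.randperm(num_items, generator=g)[:num_sequences]
    return sorted(perm.tolist())


if __name__ == "__main__":
    for epoch in range(3):
        print(epoch, make_mapping(20, 10, epoch))
```

Because the result depends only on the arguments, two dataset copies in different workers agree on the mapping for a given epoch, while different epochs still produce different subsets.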

Setting persistent_workers=False does work, but I was wondering whether there is another way to achieve this that I am missing. I took a look at Samplers, but I have never used them and I am not sure whether they could do this.
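A Sampler may fit this well. The DataLoader iterates its sampler in the main process and calls the sampler's __iter__ again at the start of every epoch, whether or not the workers are persistent, so the per-epoch randomness can live there while the Dataset stays stateless and simply exposes every candidate sequence. A minimal sketch for the Dummy example; AllItems and RandomSubsetSampler are hypothetical names:

```python
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class AllItems(Dataset):
    """Dataset exposing *every* candidate item; no per-epoch state."""

    def __init__(self, num_items, dim):
        self.data = torch.stack([torch.ones(dim) * x for x in range(num_items)])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class RandomSubsetSampler(Sampler):
    """Yields a fresh random subset of num_samples indices on every epoch."""

    def __init__(self, num_items, num_samples):
        self.num_items = num_items
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        # Runs in the main process each time the DataLoader starts an epoch,
        # so a new subset is drawn even when the workers are persistent.
        subset = torch.randperm(self.num_items)[: self.num_samples]
        return iter(sorted(subset.tolist()))


if __name__ == "__main__":
    num_sequences = 10
    dset = AllItems(num_sequences * 2, 1)
    sampler = RandomSubsetSampler(len(dset), num_sequences)
    dload = DataLoader(dset, batch_size=5, sampler=sampler,
                       num_workers=2, persistent_workers=True)
    for e in range(5):
        print(f"Epoch {e}")
        for b in dload:
            print(b)
```

Only indices are sent to the workers, so nothing has to be synchronized between them. For the video case, __iter__ would draw N sequence indices per video instead of one global subset. Note that a custom sampler cannot be combined with shuffle=True on the DataLoader.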

Thanks a lot in advance!