Hello, I am trying to sample specific parts (i.e. patches) from a large image. I would like the patch indices (i.e. which parts of each image are sampled) to be the same within a batch and randomly re-sampled after every batch.
With num_workers=0, this is straightforward, as I can change the state of the dataset after each batch is loaded from the dataloader. However, with multiple workers it is not as straightforward to alter the workers’ copies of the dataset after each batch.
Is there a good way to make this work with multiple workers? Should I write some kind of custom batch sampler that passes a random sequence of patch indices to retrieve along with each batch? If so, are there any good resources on customizing the batch sampler?
Simplified example of how this works with num_workers=0:
    from torch.utils.data import Dataset, DataLoader

    # read_image, get_patches and get_5_random_numbers are helpers defined elsewhere
    class exampleDataset(Dataset):
        def __init__(self, dirs, patch_state):
            self.dirs = dirs
            self.patch_state = patch_state

        def __len__(self):
            return len(self.dirs)

        def __getitem__(self, idx):
            image = read_image(self.dirs[idx])
            patches = get_patches(image, self.patch_state)
            return patches

        def set_patch_state(self, patch_state):
            self.patch_state = patch_state

    initial_patch_state = [0, 3, 5, 33, 72]
    dataset = exampleDataset(dirs, initial_patch_state)
    loader = DataLoader(dataset, batch_size=8, num_workers=0)
    for i, data in enumerate(loader):
        # <do training>
        # change patches to sample
        new_patch_state = get_5_random_numbers()
        dataset.set_patch_state(new_patch_state)
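For illustration, the custom batch sampler idea from the question could look roughly like the sketch below. It relies on the fact that, for a map-style dataset, DataLoader passes each element yielded by `batch_sampler` straight to `dataset[...]`, so the key can be an `(index, patch_state)` tuple rather than a plain integer. The batch sampler runs in the main process, so the patch state is drawn once per batch there and travels to the workers together with the indices. Everything named here (`PatchBatchSampler`, `ToyPatchDataset`, `collect_patch_states`, `num_patches`, `patches_per_batch`) is hypothetical, and the toy dataset only stands in for the real `read_image` / `get_patches` logic.

```python
import math
import random

import torch
from torch.utils.data import DataLoader, Dataset


class PatchBatchSampler:
    """Yields batches of (image_index, patch_state) keys.

    All keys in one batch share the same patch_state; a new one is drawn
    for every batch. Hypothetical helper, not part of PyTorch.
    """

    def __init__(self, num_images, batch_size, num_patches,
                 patches_per_batch=5, shuffle=True):
        self.num_images = num_images
        self.batch_size = batch_size
        self.num_patches = num_patches
        self.patches_per_batch = patches_per_batch
        self.shuffle = shuffle

    def __len__(self):
        return math.ceil(self.num_images / self.batch_size)

    def __iter__(self):
        order = list(range(self.num_images))
        if self.shuffle:
            random.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            # Draw the patch indices once per batch, in the main process.
            patch_state = tuple(sorted(
                random.sample(range(self.num_patches), self.patches_per_batch)))
            yield [(i, patch_state) for i in order[start:start + self.batch_size]]


class ToyPatchDataset(Dataset):
    """Stand-in dataset: each "image" is a 1-D tensor of num_patches values."""

    def __init__(self, num_images, num_patches):
        self.images = torch.arange(num_images * num_patches).reshape(
            num_images, num_patches)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, key):
        # The key is the (index, patch_state) tuple built by the batch sampler.
        idx, patch_state = key
        patches = self.images[idx][list(patch_state)]
        return patches, torch.tensor(patch_state)


def collect_patch_states(num_workers=0):
    """Iterates once over the loader and returns the patch state of each batch."""
    dataset = ToyPatchDataset(num_images=20, num_patches=100)
    batch_sampler = PatchBatchSampler(len(dataset), batch_size=8, num_patches=100)
    loader = DataLoader(dataset, batch_sampler=batch_sampler,
                        num_workers=num_workers)
    seen = []
    for patches, states in loader:
        # Every row of `states` should be identical within a batch.
        assert bool((states == states[0]).all())
        seen.append(tuple(states[0].tolist()))
    return seen
```

Because the dataset no longer holds any mutable patch state, the workers' copies never need updating. When `batch_sampler` is given, `batch_size`, `shuffle`, `sampler` and `drop_last` have to be left at their defaults in the DataLoader call. On platforms that start workers with spawn, a call such as `collect_patch_states(num_workers=2)` would need to sit under an `if __name__ == "__main__":` guard.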