Randomness at the batch level

Hello, I am trying to sample specific parts (i.e. patches) from a large image. I would like the patch indexes (i.e. the same parts of each image) to be identical within a batch and randomly re-sampled after every batch.
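For concreteness, here is one way a helper like the get_patches used in the example below could cut an image into patches and select them by index. This is a hypothetical sketch, not the poster's code. It assumes a (C, H, W) tensor image, non-overlapping square patches of size patch_size, H and W divisible by patch_size, and patches numbered row by row over the grid.

```python
import torch


def get_patches(image: torch.Tensor, patch_indexes, patch_size: int) -> torch.Tensor:
    """Hypothetical helper: cut a (C, H, W) image into non-overlapping square
    patches and keep the ones whose grid index is in patch_indexes.
    Assumes H and W are divisible by patch_size; patches are numbered row-major."""
    c = image.shape[0]
    # (C, H/p, W/p, p, p): unfold the height, then the width, into patch windows
    grid = image.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # (num_patches, C, p, p), ordered row by row over the patch grid
    patches = grid.permute(1, 2, 0, 3, 4).reshape(-1, c, patch_size, patch_size)
    return patches[torch.as_tensor(patch_indexes, dtype=torch.long)]


image = torch.arange(3 * 8 * 8, dtype=torch.float32).reshape(3, 8, 8)
selected = get_patches(image, [0, 3], patch_size=4)  # top-left and bottom-right patches
print(selected.shape)  # torch.Size([2, 3, 4, 4])
```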

With num_workers=0, this is straightforward, as I can change the state of the dataset after each batch is loaded from the DataLoader. With multiple workers, however, each worker holds its own copy of the dataset, and it is not obvious how to alter those copies after each batch.

Is there a good way to make this work with multiple workers? Should I write a custom batch sampler that passes a randomly sampled sequence of patch indexes to retrieve? If so, are there any good resources on customizing the batch sampler?

Simplified example of how this works with num_workers=0:

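from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image

class ExampleDataset(Dataset):
    def __init__(self, dirs, patch_state):
        self.dirs = dirs
        self.patch_state = patch_state  # patch indexes shared by all samples

    def __len__(self):
        return len(self.dirs)

    def __getitem__(self, idx):
        image = read_image(self.dirs[idx])
        patches = get_patches(image, self.patch_state)  # my own helper (not shown)
        return patches

    def set_patch_state(self, patch_state):
        self.patch_state = patch_state

initial_patch_state = [0, 3, 5, 33, 72]
dataset = ExampleDataset(dirs, initial_patch_state)
loader = DataLoader(dataset, batch_size=8, num_workers=0)

for i, data in enumerate(loader):
    # <do training>
    # change the patches to sample for the next batch
    new_patch_state = get_5_random_numbers()  # my own helper (not shown)
    dataset.set_patch_state(new_patch_state)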
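A self-contained toy version of the num_workers=0 loop above, using a hypothetical in-memory dataset that returns its current patch indexes instead of real patches, so the per-batch behaviour can be checked. It assumes 100 possible patch positions and uses torch.randperm as a stand-in for get_5_random_numbers; the dataset's state is updated after every batch.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyPatchDataset(Dataset):
    """Hypothetical stand-in for the example dataset: returns the patch
    indexes it would use instead of real patches."""

    def __init__(self, n, patch_state):
        self.n = n
        self.patch_state = patch_state

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(self.patch_state)

    def set_patch_state(self, patch_state):
        self.patch_state = patch_state


torch.manual_seed(0)
dataset = ToyPatchDataset(32, [0, 3, 5, 33, 72])
loader = DataLoader(dataset, batch_size=8, num_workers=0)

batches = []
for data in loader:
    batches.append(data)  # shape (8, 5): one row of patch indexes per sample
    # draw 5 distinct patch indexes (out of an assumed 100) for the next batch
    dataset.set_patch_state(torch.randperm(100)[:5].tolist())
```

This only works because, with num_workers=0, the DataLoader reads from the very dataset object the loop modifies and fetches each batch lazily when the loop asks for it. With num_workers > 0, each worker has its own copy of the dataset and may already be prefetching later batches, so the update would not reach it.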

Thanks

If you are using multiple workers, each worker uses a separate seed, derived from a base seed drawn by the main process's RNG plus the worker id. I guess you could try to use this worker information inside your Dataset.__getitem__ to either re-seed the code or change the patch_state accordingly.
To access this information, use torch.utils.data.get_worker_info().
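A minimal sketch of calling torch.utils.data.get_worker_info() inside __getitem__, using a hypothetical dataset that just reports which worker loaded each index. get_worker_info() returns None in the main process (as with num_workers=0 here) and, inside a worker, an object with fields such as id, num_workers and seed.

```python
from torch.utils.data import Dataset, DataLoader, get_worker_info


class WorkerAwareDataset(Dataset):
    """Hypothetical example: report which worker loaded each sample."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        info = get_worker_info()  # None when loading in the main process
        worker_id = -1 if info is None else info.id
        # Inside a worker, info.seed is that worker's seed; the DataLoader has
        # already used it to seed torch's RNG in the worker.
        return idx, worker_id


loader = DataLoader(WorkerAwareDataset(), batch_size=4, num_workers=0)
results = [(idx.tolist(), wid.tolist()) for idx, wid in loader]
for idx, wid in results:
    print(idx, wid)
```

With num_workers=2, the default loader hands all indices of a batch to a single worker, so worker_id would be constant within a batch. The worker information alone, however, does not tell __getitem__ when a batch starts or ends.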

Thank you for the response. I was looking into get_worker_info(), but I think the problem is that __getitem__ has no concept of a batch, so I'm not sure how you would change the state only after each batch rather than after each __getitem__ call.

You are right that the default sampler passes each sample index to __getitem__ individually (a BatchSampler could pass all indices of a batch to __getitem__ at once). However, you should hit the same limitation with num_workers=0, so how are you dealing with it now?
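A sketch of the BatchSampler variant mentioned above, under the assumption that automatic batching is switched off with batch_size=None. The BatchSampler is passed as sampler, so each list of indices it yields reaches __getitem__ in a single call, and one random draw of patch indexes can serve the whole batch. The dataset, num_patches=100 and k=5 are hypothetical, and the repeated patch-index tensor is a stand-in for read_image plus get_patches.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, SequentialSampler


class BatchPatchDataset(Dataset):
    """Hypothetical dataset whose __getitem__ receives a whole list of indices."""

    def __init__(self, n, num_patches=100, k=5):
        self.n = n
        self.num_patches = num_patches  # assumed patch positions per image
        self.k = k                      # patches kept per image

    def __len__(self):
        return self.n

    def __getitem__(self, indices):
        # One call builds the whole batch, so one draw serves every sample in it.
        patch_idx = torch.randperm(self.num_patches)[: self.k]
        # Stand-in for read_image + get_patches: one row of patch indexes per sample
        patch_rows = patch_idx.repeat(len(indices), 1)
        return torch.tensor(indices), patch_rows


dataset = BatchPatchDataset(20)
sampler = BatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)
# batch_size=None disables automatic batching: each index list from the sampler
# is passed to dataset[...] as-is and the result is not collated.
loader = DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=0)

batches = list(loader)
for indices, patch_rows in batches:
    print(indices.tolist(), tuple(patch_rows.shape))
```

Because the DataLoader seeds torch's RNG differently in each worker, workers should not draw identical patch indexes. Since the output is not collated, __getitem__ is responsible for assembling the batch itself.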

With num_workers=0, I can change the patch_state at the end of each batch (as in the example code in the original post above), so that's how I was handling it. But yes, I will proceed with a BatchSampler; it seems like the best approach.

__getitem__() will receive the patch indexes to return together with the sample index (for example, as an (index, patch_indexes) tuple produced by the batch sampler), and the batch sampler will use the same patch indexes for every sample in a batch. In this case, the dataset no longer needs a self.patch_state variable.
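A minimal sketch of this plan, assuming a custom batch sampler passed as batch_sampler= that yields lists of (index, patch_indexes) pairs. SharedPatchBatchSampler, PatchDataset, num_patches=100 and k=5 are hypothetical names and values. The batch sampler is iterated in the main process, so each batch's patch indexes are drawn there and the workers only receive the keys; no dataset state has to be synchronised across workers.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class SharedPatchBatchSampler(Sampler):
    """Hypothetical batch sampler: yields lists of (sample_index, patch_indexes)
    pairs, drawing one fresh set of patch indexes per batch."""

    def __init__(self, num_samples, batch_size, num_patches, k, shuffle=True):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_patches = num_patches  # assumed patch positions per image
        self.k = k                      # patches kept per image
        self.shuffle = shuffle

    def __iter__(self):
        n = self.num_samples
        order = torch.randperm(n) if self.shuffle else torch.arange(n)
        for start in range(0, n, self.batch_size):
            # one draw per batch, shared by every sample in it
            patch_idx = tuple(torch.randperm(self.num_patches)[: self.k].tolist())
            yield [(int(i), patch_idx) for i in order[start:start + self.batch_size]]

    def __len__(self):
        return (self.num_samples + self.batch_size - 1) // self.batch_size


class PatchDataset(Dataset):
    """Hypothetical stand-in for the example dataset, without self.patch_state."""

    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, key):
        idx, patch_idx = key  # the sampler packs both into one key
        # Real code would return get_patches(read_image(self.dirs[idx]), patch_idx)
        return idx, torch.tensor(patch_idx)


sampler = SharedPatchBatchSampler(20, batch_size=8, num_patches=100, k=5)
loader = DataLoader(PatchDataset(20), batch_sampler=sampler, num_workers=0)
batches = list(loader)
for idx, patch_idx in batches:
    print(idx.tolist(), patch_idx[0].tolist())
```

Since __getitem__ still receives one key per sample, the default collate function stacks the per-sample results into batch tensors as usual.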