Hi, my setup is rather straightforward: my Dataset’s `__getitem__` function opens an image and crops it using pre-determined coordinates. I want these coordinates to be the same for a whole batch of images, but to change for each subsequent batch.
Unfortunately, I can’t find a way to do that with DataLoader workers, as I have no means of telling them that a full batch has been completed and that they should move on to new crop coordinates (which are random but could be precomputed).
The only hacks I can see around it for the moment would be to rely on reading a file (with all the multiprocessing locking issues that could arise) or to use the time to roughly synchronize batches (e.g. changing the crop coordinates every 5 minutes by the wall clock).
I feel like there should be an easy way for me to communicate with all workers to tell them to update their crop coordinates, but I can’t seem to find it.
Thanks in advance,
If you look at the details in the documentation (and know where to look), you’ll see that each batch is assembled by a single worker, which also runs the collate function that turns the list of examples into a batch.
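One way to check this yourself is a toy dataset whose items record which worker produced them. This is only a sketch: the class `WhoLoadedMe`, the helper `batch_worker_ids`, and the sizes are illustrative names and values, not part of the original setup. It relies on `torch.utils.data.get_worker_info()`, which returns `None` in the main process and a `WorkerInfo` with an `id` inside a worker.

```python
import torch
from torch.utils.data import DataLoader, Dataset, get_worker_info


class WhoLoadedMe(Dataset):
    """Toy dataset (hypothetical): each item reports which worker produced it."""

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        info = get_worker_info()
        worker_id = -1 if info is None else info.id  # -1 means the main process
        return torch.tensor([idx, worker_id])


def batch_worker_ids(num_workers=2, batch_size=4):
    loader = DataLoader(WhoLoadedMe(), batch_size=batch_size, num_workers=num_workers)
    # For each batch, collect the set of worker ids that produced its items.
    return [set(batch[:, 1].tolist()) for batch in loader]


if __name__ == "__main__":
    for i, ids in enumerate(batch_worker_ids()):
        print(f"batch {i}: produced by worker(s) {ids}")
```

Every set should contain a single worker id, i.e. no batch mixes items from different workers. The `__main__` guard matters on platforms where workers are started with `spawn`.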
With this, you have two options to get “batch-level” processing:
- Have the dataset build and return whole batches (this is also useful e.g. when you want to group examples of similar size together) and use it with batch_size=None in the DataLoader to disable automatic batching (a dummy batch_size=1 also works, but adds an extra leading dimension of size 1),
- have the dataset return uncropped data and crop it while copying into a joint tensor in collate_fn (or disable automatic batching and do it in the main loop).
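The first option can be sketched as a dataset whose items are entire batches. This is a minimal sketch under assumptions: `CroppedBatchDataset` is a hypothetical name, and an in-memory `(N, C, H, W)` tensor stands in for the images the real dataset would open from disk.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CroppedBatchDataset(Dataset):
    """Option 1 sketch: the dataset itself returns whole, identically cropped batches."""

    def __init__(self, images, batch_size, crop_size):
        self.images = images  # stand-in for images loaded from disk, shape (N, C, H, W)
        self.batch_size = batch_size
        self.crop_size = crop_size

    def __len__(self):
        # Number of full batches; an incomplete last batch is dropped for simplicity.
        return len(self.images) // self.batch_size

    def __getitem__(self, batch_idx):
        start = batch_idx * self.batch_size
        batch = self.images[start:start + self.batch_size]
        _, _, h, w = batch.shape
        # One random crop position shared by every image in this batch.
        top = torch.randint(0, h - self.crop_size + 1, (1,)).item()
        left = torch.randint(0, w - self.crop_size + 1, (1,)).item()
        return batch[:, :, top:top + self.crop_size, left:left + self.crop_size]


if __name__ == "__main__":
    dataset = CroppedBatchDataset(torch.rand(32, 3, 64, 64), batch_size=8, crop_size=32)
    # batch_size=None disables automatic batching: each dataset item is yielded as-is.
    loader = DataLoader(dataset, batch_size=None, shuffle=True, num_workers=2)
    for batch in loader:
        print(batch.shape)  # torch.Size([8, 3, 32, 32])
```

With shuffle=True the sampler shuffles batch indices rather than individual images, so the grouping of images into batches stays fixed; that is the trade-off of building batches inside the dataset.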
Thanks a lot, I had completely missed the fact that workers don’t contribute to a shared batch but each build their own in parallel.
This does solve my issue, as each worker can now simply apply a random crop to the data in its collate function and then call the default collate function.
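A sketch of that collate-side crop, assuming the dataset returns full-size `(C, H, W)` tensors. `UncroppedDataset` and `RandomCropCollate` are hypothetical names; the collate function is a small class rather than a nested function so it stays picklable when workers are started with `spawn`. It relies on `torch.utils.data.default_collate`, which recent PyTorch versions expose publicly.

```python
import random

import torch
from torch.utils.data import DataLoader, Dataset, default_collate


class UncroppedDataset(Dataset):
    """Hypothetical stand-in that returns full-size (C, H, W) images, uncropped."""

    def __init__(self, images):
        self.images = images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx]


class RandomCropCollate:
    """Collate function applying one random crop to every sample of a batch."""

    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, samples):
        _, h, w = samples[0].shape
        s = self.crop_size
        # Called once per batch, by the worker that builds that batch.
        # PyTorch seeds Python's `random` differently in each worker.
        top = random.randint(0, h - s)
        left = random.randint(0, w - s)
        cropped = [x[:, top:top + s, left:left + s] for x in samples]
        return default_collate(cropped)


if __name__ == "__main__":
    dataset = UncroppedDataset(torch.rand(32, 3, 64, 64))
    loader = DataLoader(dataset, batch_size=8, num_workers=2,
                        collate_fn=RandomCropCollate(crop_size=32))
    for batch in loader:
        print(batch.shape)  # torch.Size([8, 3, 32, 32])
```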
In fact, this also means the worker can keep doing the cropping in `__getitem__` using crop coordinates stored on the dataset: since each worker has its own copy of the dataset, those coordinates are not shared with the other workers, so the collate function only has to call the method that updates them.
In the end this second option is much easier: without changing any of my existing code, I got it working perfectly by having worker_init_fn and collate_fn update the crop position:
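```python
from torch.utils.data import get_worker_info
from torch.utils.data._utils.collate import default_collate

def worker_init_fn(worker_id):
    # Make sure workers don't start with the same crop (update_crop: the dataset's own method)
    get_worker_info().dataset.update_crop()

def collate_fn(batch):
    dataset = get_worker_info().dataset
    print("Collating batch that was cropped with coordinates:", dataset.top, dataset.left)
    dataset.update_crop()  # Change random crop for next batch
    return default_collate(batch)
```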
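For anyone wanting to try this end to end, here is a self-contained sketch of the same idea. The dataset class `SharedCropDataset`, its `update_crop` method, and the in-memory tensor standing in for images on disk are all hypothetical; `__getitem__` also returns the crop coordinates so you can check that every item of a batch shares them. Note that `get_worker_info()` returns `None` when num_workers=0, so in that case this version never updates the crop.

```python
import random

import torch
from torch.utils.data import DataLoader, Dataset, default_collate, get_worker_info


class SharedCropDataset(Dataset):
    """Hypothetical dataset cropping every item with the same stored coordinates.

    Each DataLoader worker holds its own copy of this object, so updating
    `top`/`left` in one worker only affects the batches that worker builds.
    """

    def __init__(self, images, crop_size):
        self.images = images  # stand-in for images opened from disk, shape (N, C, H, W)
        self.crop_size = crop_size
        self.update_crop()

    def update_crop(self):
        _, _, h, w = self.images.shape
        self.top = random.randint(0, h - self.crop_size)
        self.left = random.randint(0, w - self.crop_size)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        t, l, s = self.top, self.left, self.crop_size
        crop = self.images[idx, :, t:t + s, l:l + s]
        # Also return the coordinates so the shared crop can be checked.
        return crop, torch.tensor([t, l])


def worker_init_fn(worker_id):
    # Workers are seeded differently before this runs, so each gets its own first crop.
    get_worker_info().dataset.update_crop()


def collate_fn(batch):
    collated = default_collate(batch)
    info = get_worker_info()
    if info is not None:  # None in the main process (num_workers=0)
        # All items of this batch are already cropped; pick a new crop for the next one.
        info.dataset.update_crop()
    return collated


if __name__ == "__main__":
    dataset = SharedCropDataset(torch.rand(32, 3, 64, 64), crop_size=32)
    loader = DataLoader(dataset, batch_size=8, num_workers=2,
                        worker_init_fn=worker_init_fn, collate_fn=collate_fn)
    for crops, coords in loader:
        print(crops.shape, coords[0].tolist())
```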
Thanks again, this is the elegant solution I was looking for.