Inheriting the PyTorch DataLoader class for processing "per-worker" outputs

Can we subclass the DataLoader class? If so, are there any specific restrictions?
I know we can do so for the Dataset class, but I need to know about DataLoader specifically.

Specifically, I need to process the “per-worker” return values (with DataLoader multiprocessing) of the dataset's __getitem__ calls and return the processed output to my main program.
The only way I can see is to subclass DataLoader and modify its worker classes to include the changes I need.
Are there any “cleaner” ways to do this?
I'm also not sure whether collate_fn could work here.

Would really appreciate the help!
Thanks…

I think a custom collate_fn should work: it receives all samples of the current batch, loaded in this worker process, so you could process them further there if needed.
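A minimal sketch of that idea, assuming a toy dataset that returns one float tensor per index. `ToyDataset` and `my_collate` are hypothetical names, and "centering the batch" is just a placeholder for whatever batch-level processing you need. With `num_workers > 0`, `my_collate` runs inside each worker process, and only its return value is sent back to the main process.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class ToyDataset(Dataset):
    """Hypothetical dataset returning one float tensor of shape [1] per index."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return torch.tensor([float(idx)])


def my_collate(samples):
    # 'samples' is the list of __getitem__ outputs for one batch.
    # With num_workers > 0 this function runs inside the worker process.
    batch = default_collate(samples)  # stack into shape [batch_size, 1]
    # Placeholder batch-level processing: subtract the batch mean.
    return batch - batch.mean()


if __name__ == "__main__":
    loader = DataLoader(ToyDataset(), batch_size=4, num_workers=2, collate_fn=my_collate)
    for batch in loader:
        print(batch.shape, batch.mean().item())
```

Reusing `default_collate` keeps the standard stacking behaviour, so only the extra processing step has to be written by hand.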

@ptrblck Thanks for the response…
So, I guess whatever custom code I put in collate_fn would operate on the samples at the “batch” level.
Is there an alternative if I want the processing to be done at the level of a single “sample”?
Or should I unpack the batch and work on each sample individually?

If you want to process each sample, you could perform these operations in the Dataset.__getitem__ method. I’m unsure about the worker-specific processing idea, but in case you need the worker id, you could use:

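```python
import torch

# get_worker_info() returns None in the main process (i.e. when num_workers=0).
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
    worker_id = worker_info.id
```

inside __getitem__.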
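Put together, per-sample processing plus the worker id inside `__getitem__` might look like the following sketch. `PerSampleDataset` is a hypothetical name, doubling each value stands in for real per-sample processing, and using `-1` as the id for the main process is an arbitrary convention chosen here.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class PerSampleDataset(Dataset):
    """Hypothetical dataset that processes each sample inside __getitem__."""

    def __init__(self, n=8):
        self.data = torch.arange(n, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx] * 2.0  # placeholder per-sample processing
        worker_info = torch.utils.data.get_worker_info()
        # None means we are loading in the main process (num_workers=0).
        worker_id = worker_info.id if worker_info is not None else -1
        return sample, worker_id


if __name__ == "__main__":
    loader = DataLoader(PerSampleDataset(), batch_size=2, num_workers=2)
    for samples, worker_ids in loader:
        # The default collate_fn turns the (sample, worker_id) pairs
        # into a batched tensor of samples and a tensor of ids.
        print(samples.tolist(), worker_ids.tolist())
```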

Thanks for your help! 🙂
I was able to work it out by combining namedtuple return values from __getitem__ (including the worker_id obtained as shown above) with a custom collate_fn implementation.
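The exact code from that solution isn't shown here, so the following is only one possible reconstruction, under stated assumptions: `Sample`, `TaggedDataset` and `collate_per_sample` are hypothetical names, and squaring each value is a placeholder for the real per-sample work. `__getitem__` returns a namedtuple tagged with the worker id, and the custom collate_fn unpacks the batch, processes each sample individually inside the worker, and re-packs the result.

```python
from collections import namedtuple

import torch
from torch.utils.data import DataLoader, Dataset

# Hypothetical record type; defined at module level so it can be pickled
# when batches are sent back from the worker processes.
Sample = namedtuple("Sample", ["value", "worker_id"])


class TaggedDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        info = torch.utils.data.get_worker_info()
        wid = info.id if info is not None else -1  # -1 means main process
        return Sample(value=torch.tensor(float(idx)), worker_id=wid)


def collate_per_sample(samples):
    # Runs in the worker process when num_workers > 0.
    # Unpack the batch and process each sample individually (placeholder: square it).
    values = torch.stack([s.value ** 2 for s in samples])
    worker_ids = torch.tensor([s.worker_id for s in samples])
    return Sample(value=values, worker_id=worker_ids)


if __name__ == "__main__":
    loader = DataLoader(TaggedDataset(), batch_size=4, num_workers=2,
                        collate_fn=collate_per_sample)
    for batch in loader:
        print(batch.value.tolist(), batch.worker_id.tolist())
```

Because each worker builds whole batches, all samples in one collated batch share the same worker id, so the per-worker output arrives in the main program already grouped by batch.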