If I understand correctly, each worker process runs the `__getitem__` of the `Dataset` on the indices it has received (a subset of the batch indices) and creates a list of outputs. What is unclear to me is when `collate_fn` is called to assemble that list of outputs into a batch. It would be reasonable for this function to be called by the main process, once it has received all the outputs for a single batch from the worker processes. However, the documentation states that any custom `collate_fn` needs to be picklable, so I assume that each worker process calls `collate_fn` on the portion of the batch it has processed, and that these portions are then sent to the main process, which stacks the tensors. That doesn’t make much sense to me, but I can’t think of another reason why `collate_fn` has to be declared at the top level and be picklable.

Can someone clarify when and where `collate_fn` is called, and why it needs to be pickled? Can it be a member function of the `Dataset` and use member variables?
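The last question can be probed without PyTorch, since picklability is a property of the function object itself. The sketch below is only an illustration under assumptions: `ToyDataset`, `collate`, `pad_value`, and `is_picklable` are hypothetical names, and the class is a plain Python stand-in rather than a real `Dataset` subclass. It checks whether `pickle` accepts a bound method of a top-level class versus a lambda.

```python
import pickle


class ToyDataset:
    """Hypothetical stand-in for a Dataset subclass; no PyTorch required."""

    def __init__(self, pad_value=0):
        self.pad_value = pad_value  # member variable used by collate()

    def collate(self, samples):
        # Pad every sample to the length of the longest one in the batch.
        width = max(len(s) for s in samples)
        return [list(s) + [self.pad_value] * (width - len(s)) for s in samples]


def is_picklable(obj):
    """Return True if pickle can serialize obj, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


if __name__ == "__main__":
    dataset = ToyDataset(pad_value=-1)
    # Bound method of an instance whose class is defined at module top level.
    print(is_picklable(dataset.collate))
    # Anonymous function, which pickle cannot look up by name.
    print(is_picklable(lambda batch: batch))
    # Round trip: the restored method still sees the member variable.
    restored = pickle.loads(pickle.dumps(dataset.collate))
    print(restored([[1, 2, 3], [4]]))
```

If the bound method passes this check, passing something like `collate_fn=dataset.collate` should at least meet the pickling requirement. Note that pickling a bound method also pickles the instance it is bound to, so everything stored on that instance has to be picklable as well.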