Problems with multiprocessing and a custom collate_fn: AttributeError: Can't pickle local object

Hi, I implemented my own collate_fn because I need to pad the data to a variable length. The length to which each batch is padded is also determined inside this function.
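For context, a collate_fn that pads each batch to a length computed inside the function can be sketched as below. This is a minimal illustration, not the poster's code: it assumes each dataset item is a (1-D float tensor, int label) pair, and the name pad_collate is hypothetical. The target length is simply the longest sequence in the current batch, handled by torch.nn.utils.rnn.pad_sequence.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    """Pad a batch of (sequence, label) pairs to the longest sequence in that batch."""
    # Assumed item format: (1-D float tensor of variable length, int label)
    sequences, labels = zip(*batch)
    # Original lengths, useful later for masking or packing
    lengths = torch.tensor([seq.size(0) for seq in sequences])
    # pad_sequence pads every sequence to the batch maximum, so the length varies per batch
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0.0)
    return padded, lengths, torch.tensor(labels)

padded, lengths, labels = pad_collate([(torch.ones(3), 0), (torch.ones(5), 1)])
print(padded.shape)  # torch.Size([2, 5])
print(lengths)       # tensor([3, 5])
```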

When I set num_workers > 0 in the DataLoader, it runs fine on my CPU. However, on the GPU I get the following error:
RuntimeError: cuda runtime error (3) : initialization error at /pytorch/aten/src/THC/THCCachingAllocator.cpp:507
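A likely cause, assuming the elided collate_fn moves tensors to device: with the default 'fork' start method on Linux, worker processes forked after the parent has initialized CUDA cannot use CUDA themselves, so any CUDA call inside collate_fn fails. A common workaround is to keep collate_fn on the CPU and move each finished batch to the GPU in the main process. The sketch below is hypothetical (VariableLengthDataset and cpu_collate are illustrative names, not from the post).

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

class VariableLengthDataset(Dataset):
    """Hypothetical dataset of 1-D float tensors with lengths 1..4."""
    def __init__(self, n=8):
        self.items = [torch.arange(i % 4 + 1, dtype=torch.float32) for i in range(n)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

def cpu_collate(batch):
    # Runs inside the worker processes: pad on the CPU and make no CUDA calls here
    return pad_sequence(batch, batch_first=True)

def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    loader = DataLoader(VariableLengthDataset(), batch_size=4, num_workers=2,
                        collate_fn=cpu_collate, pin_memory=use_cuda)
    shapes = []
    for batch in loader:
        # Move the finished batch to the GPU in the main process, not in collate_fn
        batch = batch.to(device, non_blocking=use_cuda)
        shapes.append(tuple(batch.shape))
    return shapes

if __name__ == "__main__":
    print(main())  # [(4, 4), (4, 4)]
```

With pin_memory=True, the host-to-device copy can overlap with computation, so doing the transfer in the training loop usually costs little.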
I read several similar threads where the recommendation is to set the multiprocessing start method to 'spawn', i.e. set_start_method('spawn').
Now if I do that, I get the following error on both CPU and GPU:
AttributeError: Can't pickle local object 'MyDataset.get_collate_fn.<locals>.collate_fn'
It may be important to know that I pass the collate_fn to the DataLoader like this:

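import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    @staticmethod
    def get_collate(device):
        # collate_fn is a local (nested) function, so it cannot be pickled
        def collate_fn(batch):
            batch = ...  # padding logic elided in the original post
            return batch
        return collate_fn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data_loader = DataLoader(MyDataset(), collate_fn=MyDataset.get_collate(device))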
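One way to make the collate function picklable, sketched under assumptions: define it at module level and bind the extra argument with functools.partial, or wrap it in a small callable class. Both can be pickled, so they survive the 'spawn' start method. MyDataset here is a hypothetical stand-in, and pad_value replaces whatever the original get_collate(device) parameterized; multiprocessing_context="spawn" is used as a per-loader alternative to calling set_start_method globally.

```python
import functools
import pickle

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    """Hypothetical stand-in for the original dataset: 1-D tensors of varying length."""
    def __init__(self):
        self.items = [torch.ones(n) for n in (2, 5, 3, 4)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

def collate_fn(batch, pad_value):
    # Top-level function: pickle stores it by name, so spawned workers can re-import it
    return pad_sequence(batch, batch_first=True, padding_value=pad_value)

class PadCollate:
    """Alternative: a callable object; its attributes are pickled along with it."""
    def __init__(self, pad_value):
        self.pad_value = pad_value

    def __call__(self, batch):
        return pad_sequence(batch, batch_first=True, padding_value=self.pad_value)

if __name__ == "__main__":
    pickle.dumps(PadCollate(0.0))  # succeeds, unlike the nested collate_fn
    loader = DataLoader(
        MyDataset(),
        batch_size=2,
        num_workers=2,
        collate_fn=functools.partial(collate_fn, pad_value=0.0),
        multiprocessing_context="spawn",  # per-loader alternative to set_start_method
    )
    for batch in loader:
        print(batch.shape)  # torch.Size([2, 5]) then torch.Size([2, 4])
```

The partial and the callable class are interchangeable here. If device is still bound in, note that moving tensors to the GPU is usually simpler in the main training loop than inside the workers.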

Same issue here. Do you have a solution?


Same thing here, but only with set_start_method('spawn') or 'forkserver'; the 'fork' start method works fine. PyTorch 1.2, Linux.
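This difference is consistent with how the start methods work: with 'fork', each worker inherits the parent's memory, so collate_fn is never pickled, whereas 'spawn' and 'forkserver' start fresh processes and the DataLoader has to pickle collate_fn to send it to them. The standard-library sketch below (all function names are illustrative) shows the underlying rule: pickle serializes functions by qualified name, which fails for a function defined inside another function.

```python
import functools
import pickle

def make_nested_collate(device):
    # Mirrors the pattern in the question: a function defined inside another function
    def collate_fn(batch):
        return [(device, item) for item in batch]
    return collate_fn

def module_level_collate(batch, device):
    # Defined at module top level, so pickle can locate it by its qualified name
    return [(device, item) for item in batch]

def is_picklable(obj):
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        # Older Pythons raise AttributeError for local objects, newer ones PicklingError
        return False

nested = make_nested_collate("cpu")
bound = functools.partial(module_level_collate, device="cpu")

print(is_picklable(nested))  # False
print(is_picklable(bound))   # True
```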