How to pass additional arguments to the collate_fn of a DataLoader?

I want to use a custom collate function that collates data based on arguments passed to it. I am not sure whether the DataLoader can pass these arguments to the collate function internally.

I don’t know where this argument would be coming from, but if you want to set an additional argument for your collate_fn while creating the DataLoader, you could wrap the function in a lambda.
If you want to pass this argument from the Dataset, you could return it from __getitem__ and use it in your custom collate_fn.

Here is a simple example comparing a plain collate_fn with the lambda approach:

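import torch
from torch.utils.data import DataLoader, TensorDataset

# simple use case: each sample is an (input, target) tuple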
def my_collate(batch):
    x = torch.stack([a for a, b in batch])
    y = torch.stack([b for a, b in batch])
    return x, y

dataset = TensorDataset(torch.arange(10), torch.arange(10, 20))
loader = DataLoader(dataset, batch_size=5, collate_fn=my_collate)

for data, target in loader:
    print(data, target)


# with argument
def my_collate(batch, arg):
    print(arg)
    x = torch.stack([a for a, b in batch])
    y = torch.stack([b for a, b in batch])
    return x, y

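# the lambda fixes arg; note that lambdas can't be pickled, so this fails wherever
# collate_fn has to be pickled (e.g. worker processes started with "spawn")
loader = DataLoader(dataset, batch_size=5, collate_fn=lambda batch: my_collate(batch, arg="myarg"))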

for data, target in loader:
    print(data, target)
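For the Dataset route, here is a minimal sketch, assuming a hypothetical DatasetWithArg class that stores the argument once and returns a reference to it with every sample. Because __getitem__ returns self.arg rather than a copy, all samples point at the same object. The collate_fn is a plain module-level function, so unlike a lambda it can be pickled.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DatasetWithArg(Dataset):
    # hypothetical dataset that carries an extra argument for the collate_fn
    def __init__(self, data, targets, arg):
        self.data = data
        self.targets = targets
        self.arg = arg  # stored once; never copied per sample

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # return a reference to the shared argument as a third element
        return self.data[index], self.targets[index], self.arg


def collate_with_arg(batch):
    # every sample carries the same object, so read it from the first one
    arg = batch[0][2]
    x = torch.stack([a for a, b, _ in batch])
    y = torch.stack([b for a, b, _ in batch])
    return x, y, arg


settings = {"name": "myarg"}  # hypothetical argument object
dataset = DatasetWithArg(torch.arange(10), torch.arange(10, 20), settings)
loader = DataLoader(dataset, batch_size=5, collate_fn=collate_with_arg)

for data, target, arg in loader:
    print(arg, data, target)
```

With num_workers > 0 each worker process gets its own copy of the dataset, but within a worker the samples still share that copy's single argument object.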

This sadly won’t work for DDP, since lambdas can’t be pickled. Any workarounds?

So, I gather there are no workarounds. Everything that needs to be passed to the collate_fn must be explicitly included in the dataset.

The good news is that this is actually relatively memory-cheap to do. Just make sure your dataset implementation doesn't make copies of whatever collate_fn needs at hand; share references to the objects instead.
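If putting the argument into the dataset is inconvenient, one commonly used alternative to the lambda is functools.partial over a module-level function: a partial object can be pickled as long as the wrapped function and the bound arguments can. A minimal sketch, assuming the same toy TensorDataset as above; whether it resolves a particular DDP setup depends on the rest of that pipeline and is not verified here.

```python
import functools
import pickle

import torch
from torch.utils.data import DataLoader, TensorDataset


# defined at module level so pickle can locate it by its qualified name
def my_collate(batch, arg):
    x = torch.stack([a for a, b in batch])
    y = torch.stack([b for a, b in batch])
    return x, y, arg


# bind the extra argument; the partial object is picklable, unlike a lambda
collate = functools.partial(my_collate, arg="myarg")

# round-trip through pickle to check that it survives serialization
restored = pickle.loads(pickle.dumps(collate))

dataset = TensorDataset(torch.arange(10), torch.arange(10, 20))
loader = DataLoader(dataset, batch_size=5, collate_fn=restored)

for data, target, arg in loader:
    print(arg, data, target)
```

A small class that stores arg in __init__ and implements __call__(self, batch) is an equivalent picklable option, provided the class is also defined at module level.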