I want to use a custom collate function that collates data based on arguments passed to it. I am not sure if it’s possible to pass arguments to the collate function internally from the DataLoader.
I don’t know where this argument would be coming from, but in case you want to set an additional argument in your collate_fn while creating the DataLoader, you could use a lambda approach.
If you want to pass this argument from the Dataset, you could return it and use it in your custom collate_fn.
Here is a simple example of the lambda approach:
import torch
from torch.utils.data import DataLoader, TensorDataset

# simple use case
def my_collate(batch):
    # batch is a list of (a, b) samples; stack each field into one tensor
    x = torch.stack([a for a, b in batch])
    y = torch.stack([b for a, b in batch])
    return x, y

dataset = TensorDataset(torch.arange(10), torch.arange(10, 20))
loader = DataLoader(dataset, batch_size=5, collate_fn=my_collate)
for data, target in loader:
    print(data, target)

# with argument
def my_collate(batch, arg):
    print(arg)
    x = torch.stack([a for a, b in batch])
    y = torch.stack([b for a, b in batch])
    return x, y

# the lambda binds the extra argument when the DataLoader is created
loader = DataLoader(dataset, batch_size=5, collate_fn=lambda batch: my_collate(batch, arg="myarg"))
for data, target in loader:
    print(data, target)
This sadly won’t work for DDP, since lambdas can’t be pickled. Any workarounds?
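One workaround that may help, assuming the failure comes from pickling the collate_fn when worker processes are started: bind the argument with functools.partial on a module-level function instead of a lambda. A partial of a function defined at module level can be pickled, whereas a lambda cannot. This is only a minimal sketch; the value "myarg" is a placeholder, and returning arg from the collate function is done here purely to make the effect visible.

```python
from functools import partial

import torch
from torch.utils.data import DataLoader, TensorDataset


def my_collate(batch, arg):
    # `arg` is bound ahead of time through functools.partial
    x = torch.stack([a for a, b in batch])
    y = torch.stack([b for a, b in batch])
    return x, y, arg


# Unlike a lambda, a partial of a module-level function is picklable.
collate_with_arg = partial(my_collate, arg="myarg")

dataset = TensorDataset(torch.arange(10), torch.arange(10, 20))
loader = DataLoader(dataset, batch_size=5, collate_fn=collate_with_arg)
batches = list(loader)
```

A small class that stores the argument in __init__ and implements __call__(self, batch) should work the same way, as long as the class is defined at module level.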
So, I gather there is no way to make the lambda itself work here. One approach that sidesteps pickling the collate_fn arguments altogether is to explicitly include everything the collate_fn needs in the dataset.
The good news is that this is actually relatively memory-cheap to do. Just make sure that your dataset implementation does not make copies of whatever you need to have at hand in collate_fn; share the objects instead.
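A rough sketch of the dataset-based approach, under the assumption that a single shared object is enough for the collate step. The class name WithSharedArg and the config dictionary are hypothetical and only serve as illustration. The dataset stores a reference to the object and returns that same reference with every sample, so nothing is copied per item.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class WithSharedArg(Dataset):
    """Hypothetical dataset that returns a shared object with every sample."""

    def __init__(self, xs, ys, shared):
        self.xs, self.ys = xs, ys
        self.shared = shared  # stored by reference, not copied

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        # every sample carries a reference to the same object
        return self.xs[idx], self.ys[idx], self.shared


def my_collate(batch):
    x = torch.stack([a for a, b, s in batch])
    y = torch.stack([b for a, b, s in batch])
    shared = batch[0][2]  # the same object for all samples, so take the first
    return x, y, shared


config = {"pad_value": 0}  # illustrative placeholder
dataset = WithSharedArg(torch.arange(10), torch.arange(10, 20), config)
loader = DataLoader(dataset, batch_size=5, collate_fn=my_collate)
x, y, shared = next(iter(loader))
```

With the default num_workers=0 the collate function receives the very same object that was passed to the dataset. With worker processes, each worker holds its own copy of the dataset, so the sharing applies within a process rather than across them.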