How to use collate_fn()

I am not sure what collate_fn does.
Is there an example that helps in understanding what it does?


You can use your own collate_fn to process the list of samples to form a batch.
The batch argument is a list with all your samples. E.g. if you would like to return variable-sized data, have a look at this thread.


As ptrblck said, collate_fn is a callable that you provide to turn the list of samples into the batch your DataLoader returns. E.g.:

    def collate_fn(batch):
        print(type(batch))  # the DataLoader passes the samples in as one list

With batch_size=4, the batch passed to collate_fn is a list of four samples. Let's check it:

<class 'list'>
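To make the check above reproducible, here is a minimal sketch. `ToyDataset` and `inspect_collate` are hypothetical names made up for illustration; the dataset returns a (features, label) tuple per item, so each element of the batch list is a tuple.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset: each item is a (features, label) tuple."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.full((3,), float(idx)), idx


def inspect_collate(batch):
    # batch is a plain Python list holding batch_size samples,
    # each one exactly what __getitem__ returned.
    print(type(batch), len(batch), type(batch[0]))
    return batch  # returned unchanged, so the loader yields the raw list


loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=inspect_collate)
first_batch = next(iter(loader))
```

Because this collate function returns the list untouched, `first_batch` is just the four raw samples; a real collate function would combine them into tensors instead.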

I recently answered another post with a similar question. Basically, collate_fn receives a list of tuples if the __getitem__ method of your Dataset subclass returns a tuple, or a plain list of elements if it returns a single element. Its main purpose is to let you control how the batch is built without reimplementing the whole batching logic yourself. Think of it as glue: you specify how the examples stick together in a batch. If you don’t provide one, PyTorch uses its default collate function, which puts batch_size examples together roughly as torch.stack would (not exactly, but it is about that simple).

The code I wrote in that post should help make this concrete. It pads sequences with 0 up to the maximum sequence length in the batch. That is why I need collate_fn: standard batching (simply using torch.stack) won’t work in my case, so I have to pad the variable-length sequences to the same size myself before creating the batch.
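The code from that post is not reproduced in this thread, so here is a minimal sketch of the same idea rather than the original. The toy `samples` list and the name `pad_collate` are assumptions for illustration; each sample is assumed to be a (sequence, label) pair.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical toy data: (sequence, label) pairs with different lengths.
samples = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4, 5]), 1),
    (torch.tensor([6]), 0),
    (torch.tensor([7, 8, 9, 10]), 1),
]


def pad_collate(batch):
    # batch is a list of (sequence, label) tuples.
    sequences, labels = zip(*batch)
    lengths = torch.tensor([len(seq) for seq in sequences])
    max_len = int(lengths.max())
    # Start from an all-zero tensor and copy each sequence into its row,
    # so shorter sequences end up right-padded with 0.
    padded = torch.zeros(len(sequences), max_len, dtype=sequences[0].dtype)
    for row, seq in enumerate(sequences):
        padded[row, : len(seq)] = seq
    return padded, torch.tensor(labels), lengths


# A plain list works as a map-style dataset (it has __getitem__ and __len__).
loader = DataLoader(samples, batch_size=4, collate_fn=pad_collate)
padded, labels, lengths = next(iter(loader))
```

Returning the original lengths alongside the padded tensor is a design choice, not a requirement; it is handy if a later step needs to mask or pack the padding.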


Where does the parameter ‘batch’ come from?

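The batch argument is supplied implicitly by the DataLoader: it collects batch_size samples from your dataset into a list and calls your collate function with that list.
You just have to pass the collate function to the DataLoader’s collate_fn parameter.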

Here’s an example:

loader_collate = DataLoader(
    dataset, shuffle=True, batch_size=5, collate_fn=collate_fn)
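To show where `batch` comes from, the loop below is a simplified pure-Python stand-in for what the DataLoader does with its collate function. It is a sketch under assumptions, not PyTorch's actual code: `simple_loader` is a hypothetical helper, and it ignores workers, samplers, and pinned memory.

```python
import random


def simple_loader(dataset, batch_size, collate_fn, shuffle=False):
    """Simplified stand-in for the DataLoader loop (not PyTorch's real code)."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.shuffle(indices)
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start : start + batch_size]
        # This list is what arrives in collate_fn as the `batch` argument.
        batch = [dataset[i] for i in batch_indices]
        yield collate_fn(batch)


def collate_fn(batch):
    # Hypothetical collate: split (value, label) pairs into two lists.
    values, labels = zip(*batch)
    return list(values), list(labels)


dataset = [(i * 10, i % 2) for i in range(12)]
batches = list(simple_loader(dataset, batch_size=5, collate_fn=collate_fn))
```

With 12 samples and batch_size=5, the last batch holds only the two leftover samples, so a collate function should not assume a fixed batch length.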

Very cool example, @Paulo_Mann. It is also possible to do it with PyTorch’s pad_sequence function (torch.nn.utils.rnn.pad_sequence), which I am not sure was there in 2018 :)
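For reference, the built-in helper is `torch.nn.utils.rnn.pad_sequence` (spelled in the singular). A minimal sketch of using it inside a collate function follows; the toy `sequences` list and the name `pad_sequence_collate` are assumptions for illustration, and each sample is assumed to be a bare 1-D tensor.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Hypothetical toy data: 1-D tensors of different lengths.
sequences = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4, 5]),
    torch.tensor([6]),
]


def pad_sequence_collate(batch):
    # batch is a list of 1-D tensors. batch_first=True gives a result of
    # shape (batch, max_len); shorter sequences are right-padded with 0.
    return pad_sequence(batch, batch_first=True, padding_value=0)


loader = DataLoader(sequences, batch_size=3, collate_fn=pad_sequence_collate)
padded = next(iter(loader))
```

If the dataset returns (sequence, label) tuples instead, the collate function would first need to separate the sequences from the labels before calling `pad_sequence`.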
