How to use collate_fn()

Hi,
I am not sure what collate_fn does.
Is there any example that would help me understand what it does?

You can use your own collate_fn to process the list of samples to form a batch.
The batch argument is a list with all your samples. For example, if you would like to return variable-sized data, have a look at this thread.
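
A minimal sketch of that idea, assuming a toy map-style dataset whose samples are 1-D tensors of different lengths (VariableLengthDataset and list_collate are hypothetical names, not taken from the linked thread). The default collate would try to stack tensors of different sizes and fail, so the custom function simply returns the list it receives:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class VariableLengthDataset(Dataset):
    """Hypothetical toy dataset: sample i is a 1-D tensor of length i + 1."""

    def __len__(self):
        return 6

    def __getitem__(self, idx):
        return torch.arange(idx + 1, dtype=torch.float32)


def list_collate(batch):
    # `batch` is a plain Python list holding batch_size samples.
    # The tensors have different lengths, so torch.stack would fail;
    # hand the list back as-is instead of building one stacked tensor.
    return batch


loader = DataLoader(VariableLengthDataset(), batch_size=3, collate_fn=list_collate)
batches = list(loader)
```

Each item the loader yields is a list of three tensors: lengths 1, 2, 3 in the first batch and 4, 5, 6 in the second.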

So, as ptrblck said, collate_fn is the callable (function) that processes the batch you want your DataLoader to return, e.g.:

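    def collate_fn(batch):
        print(type(batch))  # the container the DataLoader passes in
        print(len(batch))   # number of samples in this batch
        return batch        # return it so the loop still receives data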

In my case, with batch_size=4, batch is a list of four samples. Let's check it:

    <class 'list'>
    4
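
To reproduce this output end to end, here is a self-contained sketch. ToyDataset is a hypothetical stand-in for your own dataset, and this collate_fn returns the list unchanged so the loop still receives data:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset: 8 samples, each a (features, label) tuple."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.full((3,), float(idx)), idx


seen = []  # records (type, length) of every batch the DataLoader passes in


def collate_fn(batch):
    print(type(batch))
    print(len(batch))
    seen.append((type(batch), len(batch)))
    return batch  # returned unchanged, so the loop receives the raw list


loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=collate_fn)
for _ in loader:
    pass
```

With 8 samples and batch_size=4, the function is called twice and prints `<class 'list'>` and `4` each time.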

I recently answered another post with a similar question. Basically, collate_fn receives a list of tuples if the __getitem__ method of your Dataset subclass returns a tuple, or just a plain list of samples if it returns a single element. Its main purpose is to let you build your batch without implementing the whole batching logic manually. Think of it as the glue that specifies how examples stick together in a batch. If you don’t supply one, PyTorch’s default collate function puts batch_size examples together much as torch.stack would (not exactly, but that is the basic idea).
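
A small sketch contrasting the default behaviour with a hand-written equivalent, assuming __getitem__ returns a (features, label) tuple. manual_collate is a hypothetical name; default_collate is importable from torch.utils.data in recent PyTorch versions:

```python
import torch
from torch.utils.data import default_collate

# A batch as collate_fn receives it when __getitem__ returns a
# (features, label) tuple: a list of batch_size tuples.
batch = [(torch.tensor([float(i), float(i)]), i) for i in range(4)]


def manual_collate(batch):
    # zip(*batch) regroups the tuples field by field:
    # all features together, all labels together.
    features, labels = zip(*batch)
    return torch.stack(features), torch.tensor(labels)


default_x, default_y = default_collate(batch)
manual_x, manual_y = manual_collate(batch)
```

Both give a (4, 2) feature tensor and a tensor of four labels; a custom collate_fn only becomes necessary when this stacking step no longer works for your data.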

The code I wrote in this post should help you really understand it. It pads sequences with 0 up to the maximum sequence length in the batch. That is why I need the collate_fn: a standard batching algorithm (simply using torch.stack) won’t work in my case, because I need to manually pad sequences of different lengths to the same size before creating the batch.
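
The code from that post is not reproduced here. As a hedged stand-in, this is a minimal sketch of the same zero-padding idea, assuming each sample is a (1-D tensor, label) tuple; pad_collate and the toy data are hypothetical:

```python
import torch
from torch.utils.data import DataLoader


def pad_collate(batch):
    """Zero-pad a list of (sequence, label) samples to the longest sequence.

    Hypothetical helper; assumes each sequence is a 1-D tensor.
    """
    sequences, labels = zip(*batch)
    lengths = torch.tensor([len(seq) for seq in sequences])
    max_len = int(lengths.max())
    # Start from all zeros, then copy each sequence into the start of its row.
    padded = torch.zeros(len(sequences), max_len, dtype=sequences[0].dtype)
    for i, seq in enumerate(sequences):
        padded[i, : len(seq)] = seq
    return padded, lengths, torch.tensor(labels)


# Toy data: sequences of lengths 2, 4 and 3 with integer labels.
data = [
    (torch.tensor([1, 2]), 0),
    (torch.tensor([3, 4, 5, 6]), 1),
    (torch.tensor([7, 8, 9]), 0),
]
loader = DataLoader(data, batch_size=3, collate_fn=pad_collate)
padded, lengths, labels = next(iter(loader))
```

Returning the original lengths alongside the padded tensor is a common choice, since later code (for example masking) usually needs to know where the padding starts.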

Where does the parameter ‘batch’ come from?

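The batch parameter is supplied implicitly by the DataLoader: it gathers batch_size samples from your dataset into a list and calls your function with that list.
You just have to pass the collate function itself to the DataLoader’s collate_fn parameter.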

Here’s an example:

    from torch.utils.data import DataLoader

    loader_collate = DataLoader(
        dataset, shuffle=True, batch_size=5, collate_fn=collate_fn)
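
Roughly speaking, the DataLoader asks its sampler for batch_size indices, looks up dataset[i] for each of them, and passes the resulting list to collate_fn. The simplified sketch below imitates that loop with BatchSampler and SequentialSampler (no shuffling, no worker processes; manual_batches is a hypothetical helper, not part of PyTorch) and checks that it yields the same batches as a DataLoader:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

# Hypothetical toy dataset: a plain list of 10 (value, label) tuples.
dataset = [(torch.tensor([float(i)]), i % 2) for i in range(10)]


def collate_fn(batch):
    # Keep the raw list so we can compare it with the manual version.
    return batch


def manual_batches(dataset, batch_size, collate_fn):
    """Rough sketch of what a DataLoader does (no workers, no shuffling)."""
    sampler = BatchSampler(SequentialSampler(dataset), batch_size, drop_last=False)
    for indices in sampler:
        # This list is exactly the `batch` argument collate_fn receives.
        samples = [dataset[i] for i in indices]
        yield collate_fn(samples)


loader_batches = list(DataLoader(dataset, batch_size=5, collate_fn=collate_fn))
manual = list(manual_batches(dataset, 5, collate_fn))
```

With shuffle=True, as in the example above, the only difference is that the indices come from a random sampler, so each batch holds a different random subset of samples.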

Very cool example, @Paulo_Mann. It’s also possible to do this with PyTorch’s pad_sequence function (torch.nn.utils.rnn.pad_sequence), which I’m not sure existed back in 2018 :)
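
A minimal sketch of that alternative, assuming the same (1-D tensor, label) sample format; pad_sequence_collate and the toy data are hypothetical names:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_sequence_collate(batch):
    # Split the (sequence, label) tuples into two groups.
    sequences, labels = zip(*batch)
    # pad_sequence pads every sequence to the longest one with padding_value;
    # batch_first=True gives shape (batch, max_len) instead of (max_len, batch).
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    return padded, torch.tensor(labels)


data = [
    (torch.tensor([1, 2]), 0),
    (torch.tensor([3, 4, 5, 6]), 1),
    (torch.tensor([7, 8, 9]), 0),
]
loader = DataLoader(data, batch_size=3, collate_fn=pad_sequence_collate)
padded, labels = next(iter(loader))
```

Note that pad_sequence returns (max_len, batch, ...) by default, so batch_first=True matters if the rest of your model expects the batch dimension first.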

Thanks, this helped me. However, I’m still wondering: what exactly gets passed in the batch parameter of collate_fn?