I recently answered another post with a similar question, but basically: collate_fn receives a list of tuples if the __getitem__ method of your Dataset subclass returns a tuple, or just a plain list if it returns a single element. Its main purpose is to build your batch without you having to implement the batching logic manually. Think of it as the glue that specifies how individual examples are stuck together into a batch. If you don't supply one, PyTorch simply puts batch_size examples together, roughly as you would with torch.stack (not exactly that, but it is about as simple).
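For instance, here is a minimal sketch (the dataset, its contents, and the function names are just assumptions for illustration) of what a collate_fn receives and what the default stacking behaviour roughly looks like:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Toy dataset whose __getitem__ returns a (features, label) tuple."""
    def __init__(self):
        self.xs = torch.randn(8, 3)   # 8 examples, 3 features each
        self.ys = torch.arange(8)     # one integer label per example

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        # Returning a tuple means collate_fn will get a list of tuples.
        return self.xs[idx], self.ys[idx]

def my_collate(batch):
    # `batch` is a list like [(x0, y0), (x1, y1), ...], one tuple per example.
    xs, ys = zip(*batch)
    # Roughly what the default collation does: stack each field into one tensor.
    return torch.stack(xs), torch.stack(ys)

loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=my_collate)
for x_batch, y_batch in loader:
    print(x_batch.shape, y_batch.shape)  # torch.Size([4, 3]) torch.Size([4])
```

With fixed-size examples like these, the default collate_fn would give you essentially the same result; a custom one only becomes necessary when the examples cannot be stacked directly, as with the variable-length sequences below.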
The following code, which I wrote for this post, should help you get a real feel for it. It pads every sequence with 0 up to the maximum sequence length in the batch. That is exactly why I need a collate_fn here: the standard batching algorithm (simply calling torch.stack) won't work in my case, because the sequences have variable lengths and I have to pad them to the same size manually before creating the batch.