I am not sure what collate_fn does.
Is there an example that would help me understand it?
You can use your own collate_fn to process the list of samples to form a batch. The batch argument is a list with all your samples. E.g. if you would like to return variable-sized data, have a look at this thread.
So, as ptrblck said, collate_fn is the callable that processes the batch you want to return from your DataLoader. E.g.:

    def collate_fn(batch):
        print(type(batch))
        print(len(batch))
In my case, with batch_size=4, batch is a list of size four. Let's check it:

    <class 'list'>
    4
I recently answered another post with a similar question. Basically, collate_fn receives a list of tuples if the __getitem__ method of your Dataset subclass returns a tuple, or just a plain list of elements if it returns a single element. Its main purpose is to let you build your batch without much manual work. Think of it as the glue that specifies how examples stick together in a batch. If you don't provide one, PyTorch simply puts batch_size examples together much as torch.stack would (not exactly, but it is about that simple).
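To make the "list of tuples" point concrete, here is a minimal sketch. ToyDataset and stack_collate are hypothetical names made up for illustration. The collate function stacks the samples by hand, which for this simple case gives the same result as leaving collate_fn unset.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset: each item is a (features, label) tuple."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.full((3,), float(idx)), idx % 2


def stack_collate(batch):
    # batch is a list of (features, label) tuples, one per sample
    features = torch.stack([x for x, _ in batch])  # shape: (batch_size, 3)
    labels = torch.tensor([y for _, y in batch])   # shape: (batch_size,)
    return features, labels


loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=stack_collate)
features, labels = next(iter(loader))

# Same data batched with the default collate function, for comparison
default_features, default_labels = next(iter(DataLoader(ToyDataset(), batch_size=4)))
same = torch.equal(features, default_features) and torch.equal(labels, default_labels)
print(features.shape, labels.shape, same)
```

For fixed-size samples like these you would normally not write a collate function at all. It becomes useful once the samples cannot simply be stacked.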
The code I wrote in this post should help make this concrete. It pads sequences with 0 up to the length of the longest sequence in the batch. That is why I need collate_fn: standard batching (simply using torch.stack) won't work in my case, so I have to pad the variable-length sequences to the same size myself before creating the batch.
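The code from the linked post is not reproduced here, so the following is only a minimal sketch of the same idea, not the original. The sample data and the name pad_collate are assumptions for illustration. Each batch is padded with 0 up to its own longest sequence, and the true lengths are returned alongside.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical data: a plain list of 1-D tensors works as a map-style dataset
sequences = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4, 5]),
    torch.tensor([6]),
    torch.tensor([7, 8, 9, 10]),
]


def pad_collate(batch):
    # batch is a list of 1-D tensors with different lengths
    lengths = torch.tensor([len(seq) for seq in batch])
    max_len = int(lengths.max())
    # Start from an all-zero (batch_size, max_len) tensor, then copy each sequence in
    padded = torch.zeros(len(batch), max_len, dtype=batch[0].dtype)
    for i, seq in enumerate(batch):
        padded[i, : len(seq)] = seq
    return padded, lengths


loader = DataLoader(sequences, batch_size=2, collate_fn=pad_collate)
batches = list(loader)
for padded, lengths in batches:
    print(padded.tolist(), lengths.tolist())
```

Returning the lengths is a design choice: downstream code (masking, packing for an RNN) usually needs to know where the padding starts.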
Where does the parameter 'batch' come from?
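The batch argument is supplied implicitly by the DataLoader: it collects batch_size samples from your dataset into a list and calls your collate function with that list.
You just have to pass the collate function itself (its name, without calling it) to the DataLoader's collate_fn parameter.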
Here's an example:

    from torch.utils.data import DataLoader

    # dataset and collate_fn are defined elsewhere
    loader_collate = DataLoader(
        dataset,
        shuffle=True,
        batch_size=5,
        collate_fn=collate_fn,
    )
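To show where batch comes from, here is a rough pure-Python stand-in for what the loader does on each iteration. simple_loader is a hypothetical helper, not PyTorch's actual implementation, and it leaves out shuffling, samplers and worker processes.

```python
def simple_loader(dataset, batch_size, collate_fn):
    """Simplified stand-in for DataLoader (illustration only)."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        # The loader builds the list of samples by indexing the dataset ...
        batch = [dataset[i] for i in range(start, stop)]
        # ... and passes that list to collate_fn as its only argument
        yield collate_fn(batch)


def collate_fn(batch):
    # Just report what was received
    return {"size": len(batch), "samples": batch}


dataset = ["a", "b", "c", "d", "e", "f", "g"]
batches = list(simple_loader(dataset, batch_size=5, collate_fn=collate_fn))
print(batches)
```

So you never call collate_fn yourself. The loader calls it once per batch, and the last batch may be smaller than batch_size.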
Very cool example, @Paulo_Mann. It is also possible to do it with PyTorch's pad_sequence function (torch.nn.utils.rnn.pad_sequence), which I am not sure was there in 2018 :)
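A minimal sketch of that variant, assuming every sample is a 1-D tensor. The name pad_collate and the sample data are made up for illustration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_collate(batch):
    # batch is a list of 1-D tensors; pad each with 0 to the longest one in the batch
    return pad_sequence(batch, batch_first=True, padding_value=0)


# Hypothetical variable-length samples
sequences = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]

loader = DataLoader(sequences, batch_size=3, collate_fn=pad_collate)
padded = next(iter(loader))
print(padded)
```

With batch_first=True the result has shape (batch_size, max_len). Without it, the default layout is (max_len, batch_size).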