Hi,
I am not sure what collate_fn does.
Is there an example that helps to understand what it does?
You can use your own collate_fn to process the list of samples to form a batch. The batch argument is a list with all your samples. E.g. if you would like to return variable-sized data, have a look at this thread.
So, as ptrblck said, the collate_fn is your callable/function that processes the batch you want to return from your DataLoader, e.g.
def collate_fn(batch):
    # batch is the list of samples the DataLoader collected for one batch
    print(type(batch))
    print(len(batch))
In my case, with batch_size=4, it receives a list of size four. Let's check it:
<class 'list'>
4
I have recently answered another post with a similar question. But basically, the collate_fn receives a list of tuples if the __getitem__ function of your Dataset subclass returns a tuple, or just a normal list if your Dataset subclass returns only one element. Its main objective is to create your batch without spending much time implementing it manually. Try to see it as a glue with which you specify the way examples stick together in a batch. If you don't use it, PyTorch only puts batch_size examples together as you would using torch.stack (not exactly that, but it is as simple as that).
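To make the "list of tuples" point concrete, here is a minimal sketch. ToyDataset and inspect_collate are hypothetical names made up for this illustration; the collate function records the types it receives and then stacks the samples by hand, and the result is compared with what the DataLoader produces when no collate_fn is given.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset: each item is a (features, label) tuple."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.full((3,), float(idx)), idx


def inspect_collate(batch):
    # batch is a plain Python list holding batch_size results of __getitem__
    features = torch.stack([sample[0] for sample in batch])
    labels = torch.tensor([sample[1] for sample in batch])
    return features, labels, type(batch), type(batch[0])


loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=inspect_collate)
features, labels, batch_type, sample_type = next(iter(loader))

# Same data without a collate_fn, for comparison with the default behaviour
default_features, default_labels = next(iter(DataLoader(ToyDataset(), batch_size=4)))
```

Here batch_type should be list and sample_type should be tuple, and for this simple fixed-size data the hand-written stacking gives the same tensors as the default collation.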
The following code, which I wrote for that post, should help you understand. It pads sequences with 0 up to the maximum sequence length in the batch. That is why I need the collate_fn: a standard batching algorithm (simply using torch.stack) won't work in my case, and I need to manually pad sequences of variable length to the same size before creating the batch.
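As a rough illustration of that padding idea (not the original code from the linked post), a collate_fn could look something like the sketch below. The name pad_collate and the toy data are assumptions made for this example; it assumes each sample is a (1-D tensor, label) pair.

```python
import torch
from torch.utils.data import DataLoader


def pad_collate(batch):
    # batch: list of (sequence, label) tuples; sequences are 1-D tensors
    # of different lengths, so torch.stack would fail on them directly
    lengths = torch.tensor([len(seq) for seq, _ in batch])
    max_len = int(lengths.max())
    # Start from zeros, then copy each sequence into the front of its row
    padded = torch.zeros(len(batch), max_len, dtype=batch[0][0].dtype)
    for i, (seq, _) in enumerate(batch):
        padded[i, : len(seq)] = seq
    labels = torch.tensor([label for _, label in batch])
    return padded, lengths, labels


# Hypothetical toy data; a plain list works as a map-style dataset
data = [
    (torch.tensor([1, 2, 3]), 0),
    (torch.tensor([4, 5]), 1),
    (torch.tensor([6]), 0),
    (torch.tensor([7, 8, 9, 10]), 1),
]
loader = DataLoader(data, batch_size=2, collate_fn=pad_collate)
batches = list(loader)
```

Each batch is padded only to the longest sequence in that batch, so the first batch has width 3 and the second has width 4. Returning the original lengths alongside the padded tensor is a design choice that makes it easy to mask or pack the sequences later.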
Where does the parameter ‘batch’ come from?
The batch parameter is implicit in the DataLoader: the DataLoader passes it in when it calls your function.
You just have to pass the collate function's name to the DataLoader's collate_fn parameter.
Here’s an example:
loader_collate = DataLoader(
    dataset, shuffle=True, batch_size=5, collate_fn=collate_fn)
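A small sketch can show where batch comes from and how often the function runs. The calls list and the stand-in dataset are assumptions for this illustration; the default num_workers=0 is kept so the bookkeeping happens in the main process.

```python
from torch.utils.data import DataLoader

calls = []  # records the size of every batch handed to collate_fn


def collate_fn(batch):
    # The DataLoader calls this for us; we never call it ourselves
    calls.append(len(batch))
    return batch


dataset = list(range(10))  # hypothetical stand-in dataset with 10 samples
loader_collate = DataLoader(dataset, batch_size=5, collate_fn=collate_fn)

for batch in loader_collate:
    pass  # each iteration triggers exactly one collate_fn call
```

With 10 samples and batch_size=5, calls ends up with two entries of 5, i.e. the DataLoader invokes collate_fn once per batch, each time with a list of batch_size samples.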
Very cool example, @Paulo_Mann. It is also possible to do it with PyTorch's pad_sequence function (which I am not sure was there in 2018 :)).
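For reference, a minimal sketch of that variant using torch.nn.utils.rnn.pad_sequence might look like this. The function name pad_sequence_collate and the toy data are made up for the example, and samples are assumed to be (1-D tensor, label) pairs.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_sequence_collate(batch):
    # Split the list of (sequence, label) tuples into two parts
    sequences = [seq for seq, _ in batch]
    labels = torch.tensor([label for _, label in batch])
    # pad_sequence pads every sequence to the longest one in the list;
    # batch_first=True gives shape (batch_size, max_len)
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    return padded, labels


# Hypothetical toy data with sequences of different lengths
data = [(torch.tensor([1, 2, 3]), 0), (torch.tensor([4]), 1)]
loader = DataLoader(data, batch_size=2, collate_fn=pad_sequence_collate)
padded, labels = next(iter(loader))
```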
Thanks, this helped me. However, I’m still wondering what gets passed in the batch parameter of collate_fn?
I am wondering: when we run for batch in loader:, will it call collate_fn once or n times?