I'm using a custom collate_fn, but there is a problem in the fetch step

I needed a dataset with channel-extended images and bbox labels.
So I created a custom dataset and dataloader by subclassing PyTorch's Dataset and DataLoader.
Because the labels in my dataset have variable lengths, I wrote a custom collate_fn and assigned it as ‘self.collate_fn’.
I think it works as I intended.
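For context, here is a minimal sketch of why the default collation fails with variable-length labels. `ToyBoxDataset`, its 4-channel HWC images and its (n, 5) box arrays are hypothetical stand-ins for the real dataset. Without a collate_fn, DataLoader tries to stack the label arrays of a batch into one tensor and raises a RuntimeError when their lengths differ.

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class ToyBoxDataset(Dataset):
    """Hypothetical stand-in: HWC image plus an (n, 5) box array with varying n."""

    def __init__(self, box_counts=(2, 3, 1)):
        self.box_counts = box_counts

    def __len__(self):
        return len(self.box_counts)

    def __getitem__(self, idx):
        image = np.zeros((8, 8, 4), dtype=np.float32)  # 4 channels, e.g. RGB + extra
        boxes = np.zeros((self.box_counts[idx], 5), dtype=np.float32)
        return image, boxes


def default_collate_fails():
    """Return True if the default collate_fn cannot stack the variable-length labels."""
    loader = DataLoader(ToyBoxDataset(), batch_size=2)  # no collate_fn -> default
    try:
        next(iter(loader))  # first batch has 2 and 3 boxes -> stacking fails
    except RuntimeError:
        return True
    return False


print(default_collate_fails())  # True
```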

But an error occurs during the ‘fetch’ step.

At the line ‘return self.collate_fn(data)’, ‘self.collate_fn’ is already a list object.
Why is that? I expected it to be a function object.
I’d like to know what causes this error.

Here’s my custom collate_fn.
Is there any problem with these return values? Thanks for your help.

import torch


def collate_fn_ju0(batch):
    # `batch` is the list of (image, label) samples returned by
    # Dataset.__getitem__ for one batch, not the dataset itself
    print('collating..')
    images = [d[0] for d in batch]
    labels = [d[1] for d in batch]

    print('converting ndarray to tensor..\n')
    for i in range(len(batch)):  # len(batch) is the batch size
        images[i] = torch.from_numpy(images[i])
        labels[i] = torch.from_numpy(labels[i])

    # HWC images -> NHWC batch -> NCHW; .contiguous() goes after permute,
    # because permute returns a non-contiguous view
    images_converted = torch.stack(images, 0).permute(0, 3, 1, 2).contiguous()

    max_num_labels = max(label.shape[0] for label in labels)

    if max_num_labels > 0:
        # zero padding up to (max_num_labels, 5) per sample;
        # use torch.full(..., -1.0) to pad with -1 instead
        labels_padded = torch.zeros((len(labels), max_num_labels, 5))
        for idx, label in enumerate(labels):
            if label.shape[0] > 0:
                labels_padded[idx, :label.shape[0], :] = label
    else:
        # no boxes anywhere in the batch: keep a single zero row per sample
        labels_padded = torch.zeros((len(labels), 1, 5))

    return images_converted, labels_padded
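The padding step can also be written with `torch.nn.utils.rnn.pad_sequence`. This is a sketch under the assumption that each sample is an (image, boxes) pair with an HWC image and an (n, 5) box array; `pad_collate` and its `pad_value` argument are hypothetical names. One behavioral difference from the hand-written padding: when no sample in the batch has any boxes, pad_sequence returns a (batch, 0, 5) tensor rather than (batch, 1, 5).

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_collate(batch, pad_value=0.0):
    """Hypothetical collate_fn for (image, boxes) samples: HWC images, (n_i, 5) boxes."""
    images, labels = zip(*batch)  # batch is a list of samples, not the dataset
    # Accept ndarrays or tensors; stack HWC -> NHWC, then permute to NCHW
    images = torch.stack([torch.as_tensor(img) for img in images], 0)
    images = images.permute(0, 3, 1, 2).contiguous()
    # Pad every (n_i, 5) label tensor to (max_n, 5) -> result is (batch, max_n, 5)
    labels = [torch.as_tensor(lbl, dtype=torch.float32) for lbl in labels]
    labels = pad_sequence(labels, batch_first=True, padding_value=pad_value)
    return images, labels
```

It is passed as `DataLoader(dataset, batch_size=..., collate_fn=pad_collate)`; for -1 padding, `functools.partial(pad_collate, pad_value=-1.0)` keeps the single-argument call signature DataLoader expects.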

I don’t fully understand your use case, as your custom collate_fn seems to iterate over the entire dataset and “convert” the data and targets, while its original purpose is to create a batch (or batch-like object) from the already processed samples returned by Dataset.__getitem__.
Could you explain your use case a bit more and how you are using this collate_fn?
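To illustrate that contract, here is a minimal sketch in which a plain Python list stands in for a Dataset and `inspect_collate` is a hypothetical name: DataLoader fetches `batch_size` items through `__getitem__` and passes them to collate_fn as a list, so collate_fn never sees the dataset object.

```python
from torch.utils.data import DataLoader


def indices_per_batch():
    """Show what collate_fn receives for a 5-item dataset with batch_size=2."""
    # Stand-in for Dataset.__getitem__ outputs: item i is the pair (i, 10 * i)
    samples = [(i, 10 * i) for i in range(5)]

    def inspect_collate(batch):
        # batch is a plain list of batch_size samples, not the dataset object
        assert isinstance(batch, list)
        return [idx for idx, _ in batch]

    loader = DataLoader(samples, batch_size=2, collate_fn=inspect_collate)
    return list(loader)


print(indices_per_batch())  # [[0, 1], [2, 3], [4]]
```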


Thanks for replying.
I found the problem.
I should have assigned the function itself, collate_fn, but I assigned collate_fn(data), so self.collate_fn held the function's return value instead of the function.
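A minimal sketch of the difference, using a hypothetical `my_collate` and a plain list as the dataset: DataLoader needs the function object, while `my_collate(data)` runs immediately and hands over its return value, so the fetch step's `self.collate_fn(data)` ends up trying to call a list.

```python
from torch.utils.data import DataLoader


def my_collate(batch):
    """Hypothetical collate_fn: returns the batch of samples as a plain list."""
    return list(batch)


def demo(data=(1, 2, 3, 4)):
    data = list(data)
    # Correct: pass the function object itself; DataLoader calls it once per batch.
    good = DataLoader(data, batch_size=2, collate_fn=my_collate)
    batches = list(good)

    # Wrong: my_collate(data) runs right away, so collate_fn holds its return
    # value (a list), and fetching a batch tries to call that list.
    try:
        bad = DataLoader(data, batch_size=2, collate_fn=my_collate(data))
        next(iter(bad))
        error = None
    except TypeError as err:
        error = str(err)
    return batches, error


print(demo())  # ([[1, 2], [3, 4]], "'list' object is not callable")
```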