I needed a dataset with channel-extended images and bounding-box (bbox) labels.
So I wrote a custom dataset and dataloader that inherit from PyTorch's Dataset and DataLoader.
Because the labels in my dataset have variable lengths, I created a custom collate_fn function and assigned it to 'self.collate_fn'.
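For reference, here is a minimal sketch of that kind of setup. It assumes 4-channel HxWxC float32 ndarrays and labels of shape (num_boxes, 5). `ToyBoxDataset` and `pad_collate` are hypothetical names, not the code from this post. The collate function is handed to DataLoader as a function object, and it receives a list of (image, label) samples:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class ToyBoxDataset(Dataset):
    """Hypothetical stand-in: HxWxC images plus (num_boxes, 5) label arrays."""

    def __init__(self, num_samples=6, size=8, channels=4):
        rng = np.random.default_rng(0)
        self.images = [rng.random((size, size, channels), dtype=np.float32)
                       for _ in range(num_samples)]
        # Sample i has i % 3 boxes, so label lengths vary (0, 1 or 2 rows)
        self.labels = [rng.random((i % 3, 5), dtype=np.float32)
                       for i in range(num_samples)]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


def pad_collate(batch):
    # `batch` is a list of (image, label) tuples, one per sample
    images = torch.stack([torch.from_numpy(img) for img, _ in batch])
    images = images.permute(0, 3, 1, 2).contiguous()  # (B, H, W, C) -> (B, C, H, W)
    labels = [torch.from_numpy(lbl) for _, lbl in batch]
    # Pad to the longest label in the batch (at least 1 row) with -1
    max_len = max(1, max(lbl.shape[0] for lbl in labels))
    padded = torch.full((len(labels), max_len, 5), -1.0)
    for i, lbl in enumerate(labels):
        padded[i, :lbl.shape[0]] = lbl
    return images, padded


# Pass the function itself; DataLoader calls it once per batch
loader = DataLoader(ToyBoxDataset(), batch_size=3, collate_fn=pad_collate)
images, labels = next(iter(loader))
print(images.shape, labels.shape)
```

Padding with -1 keeps a padded row distinguishable from a real box whose values happen to be zero.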
I think it works as I intended.
But an error occurs at the 'fetch' step.
At the line 'return self.collate_fn(data)', 'self.collate_fn' is already a list object.
Why is that? I expected it to be a function object.
I'd like to know what causes this error.
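A minimal sketch that reproduces this symptom, under the assumption that the collate function was called when it was passed in rather than passed as a function. `my_collate` is a hypothetical name. DataLoader stores whatever it receives as `collate_fn` and only calls it when a batch is fetched, so the failure shows up at that line:

```python
from torch.utils.data import DataLoader


def my_collate(batch):
    # Hypothetical collate function: returns the batch unchanged (a list)
    return batch


data = [1, 2, 3, 4]

# Correct: pass the function object; DataLoader calls it once per batch
good = DataLoader(data, batch_size=2, collate_fn=my_collate)
print(callable(good.collate_fn), next(iter(good)))

# Wrong: my_collate(data) runs immediately and returns a list,
# so DataLoader stores that list as its collate_fn
bad = DataLoader(data, batch_size=2, collate_fn=my_collate(data))
err = None
try:
    next(iter(bad))  # fails inside the fetcher at `return self.collate_fn(data)`
except TypeError as e:
    err = e
print(type(bad.collate_fn).__name__, err)
```

Printing `type(self.collate_fn)` right after constructing the custom DataLoader is a quick way to check which of the two cases applies.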
Here’s my custom collate_fn.
Are there any problems with these return values? Thanks for your help.
import torch

def collate_fn_ju0(batch):
    # `batch` is the list of (image, label) samples for one batch,
    # not the Dataset object, so it has no `total_num_samples` attribute.
    print('collating..')
    images = [d[0] for d in batch]
    labels = [d[1] for d in batch]
    print('converting ndarray to tensor..\n')
    for i in range(len(batch)):
        images[i] = torch.from_numpy(images[i])
        labels[i] = torch.from_numpy(labels[i])
    # (B, H, W, C) -> (B, C, H, W); contiguous() must come after permute()
    images_converted = torch.stack(images, 0).permute(0, 3, 1, 2).contiguous()
    # Pad every label tensor to the longest one in the batch with -1
    # (torch.zeros(...) * -1 is still all zeros). Keep at least one row
    # so the result is never empty along dim 1.
    max_num_labels = max(label.shape[0] for label in labels)
    labels_padded = torch.full((len(labels), max(max_num_labels, 1), 5), -1.0)
    for idx, label in enumerate(labels):
        if label.shape[0] > 0:
            # Write into the same tensor that is returned
            # (the original filled `label_padded` but returned `labels_padded`)
            labels_padded[idx, :label.shape[0], :] = label
    return images_converted, labels_padded