How to create a dataloader with variable-size input

By default, the DataLoader stacks the input images to form a tensor of size N*C*H*W, so every image in the batch must have the same height and width. To load a batch of variable-size images, we have to supply our own collate_fn, the function used to pack a list of samples into a batch.
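To see why the default behaviour breaks, here is a minimal sketch that stacks tensors by hand, which is roughly what the default collate does for tensor samples. The tensor shapes are made up for illustration.

```python
import torch

# Two hypothetical "images": same channel count, different widths.
imgs = [torch.rand(3, 224, 300), torch.rand(3, 224, 260)]

try:
    torch.stack(imgs)  # stacking requires every tensor to have the same shape
    stacked_ok = True
except RuntimeError as err:
    stacked_ok = False
    print("stack failed:", err)

# Tensors of identical size stack into a single N*C*H*W tensor without trouble.
same = torch.stack([torch.rand(3, 224, 224) for _ in range(4)])
print(same.shape)
```

The first call fails because the widths differ, while the second produces one batch tensor.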

For image classification, the input to collate_fn is a list of length batch_size. Each element is a tuple whose first element is the input image (a torch.FloatTensor) and whose second element is the image label, which is simply an int. Because the samples in a batch have different sizes, we cannot stack them, so we store the images in a list and the corresponding labels in a torch.LongTensor. We then put the image list and the label tensor into a list and return the result.

Here is a very simple snippet to demonstrate how to write a custom collate_fn:

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

# a simple custom collate function, just to show the idea
def my_collate(batch):
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)
    return [data, target]


def show_image_batch(img_list, title=None):
    num = len(img_list)
    fig = plt.figure()
    for i in range(num):
        ax = fig.add_subplot(1, num, i+1)
        ax.imshow(img_list[i].numpy().transpose([1,2,0]))
        # title defaults to None, so only set a title when one was passed in
        if title is not None:
            ax.set_title(title[i])

    plt.show()

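# No RandomCrop here, to show that the custom collate_fn can handle images of different sizes.
# Resize with an int scales the shorter side to 224 and keeps the aspect ratio
# (transforms.Scale is the old, deprecated name for Resize).
train_transforms = transforms.Compose([transforms.Resize(size=224),
                                       transforms.ToTensor(),
                                       ])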

# change root to valid dir in your system, see ImageFolder documentation for more info
train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset",
                                     transform=train_transforms)

trainset = DataLoader(dataset=train_dataset,
                      batch_size=4,
                      shuffle=True,
                      collate_fn=my_collate, # use custom collate function here
                      pin_memory=True)

trainiter = iter(trainset)
imgs, labels = next(trainiter)  # iterators have no .next() method in Python 3

# print(type(imgs), type(labels))
show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels])
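
If you want to try the same idea without an image folder on disk, the sketch below swaps ImageFolder for a toy dataset that returns random tensors of varying width. The RandomSizeImages class and its sizes are hypothetical and exist only for this illustration. The collate function is the same as above.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RandomSizeImages(Dataset):
    """Hypothetical toy dataset: random 3-channel images whose width varies."""

    def __init__(self, n=8):
        self.widths = [200 + 10 * i for i in range(n)]

    def __len__(self):
        return len(self.widths)

    def __getitem__(self, idx):
        img = torch.rand(3, 224, self.widths[idx])  # C*H*W, width depends on idx
        label = idx % 2                             # plain int label
        return img, label


def my_collate(batch):
    # keep images in a list because they cannot be stacked
    data = [item[0] for item in batch]
    # labels are plain ints, so they fit in one LongTensor
    target = torch.LongTensor([item[1] for item in batch])
    return [data, target]


loader = DataLoader(RandomSizeImages(), batch_size=4, collate_fn=my_collate)
imgs, labels = next(iter(loader))
print(len(imgs), [tuple(im.shape) for im in imgs], labels)
```

Since imgs is a Python list rather than a single tensor, a model has to process the images one at a time (for example with img.unsqueeze(0)) or the images have to be padded to a common size first.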