By default, the `DataLoader` in PyTorch stacks the input images to form a tensor of size N×C×H×W, so every image in the batch must have the same height and width. To load a batch of variable-size input images, we have to supply our own `collate_fn`, which is the function used to pack a list of samples into a batch.
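To see why a custom function is needed, here is a minimal sketch (not from the original code) that calls PyTorch's default collate function directly on two toy batches. It assumes a PyTorch version that exposes `torch.utils.data.default_collate` publicly (older releases keep it under `torch.utils.data._utils.collate`). Same-size images are stacked into one N×C×H×W tensor, while images of different sizes make the underlying `torch.stack` call raise an error.

```python
import torch
from torch.utils.data import default_collate  # assumes a recent PyTorch release

# Two fake 3-channel "images" with the same height and width, plus int labels
same = [(torch.rand(3, 4, 4), 0), (torch.rand(3, 4, 4), 1)]
imgs, labels = default_collate(same)
print(imgs.shape)  # torch.Size([2, 3, 4, 4])
print(labels)      # tensor([0, 1])

# Two images with different heights and widths cannot be stacked
mixed = [(torch.rand(3, 4, 4), 0), (torch.rand(3, 5, 6), 1)]
try:
    default_collate(mixed)
    stacked_ok = True
except RuntimeError as err:
    stacked_ok = False
    print("default_collate failed:", err)
```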
For image classification, the input to `collate_fn` is a list of length `batch_size`. Each element is a tuple whose first element is the input image (a `torch.FloatTensor`) and whose second element is the image label, which is simply an `int`. Because the images in a batch have different sizes, we store them in a list and store the corresponding labels in a `torch.LongTensor`. Then we put the image list and the label tensor into a list and return the result.
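One way to check this input format for yourself is an identity collate function that returns the batch unchanged. The sketch below is an illustration, not part of the original code: it uses a plain Python list of `(image, label)` tuples as the dataset, which `DataLoader` accepts because a list supports indexing and `len()`.

```python
import torch
from torch.utils.data import DataLoader

# A plain list works as a map-style dataset: each sample is (image, label)
samples = [
    (torch.rand(3, 224, 224), 0),
    (torch.rand(3, 224, 300), 1),
    (torch.rand(3, 256, 224), 2),
]

# An identity collate_fn returns exactly what the DataLoader passes to it
loader = DataLoader(samples, batch_size=3, collate_fn=lambda batch: batch)
batch = next(iter(loader))

print(type(batch), len(batch))                 # <class 'list'> 3
print(type(batch[0]))                          # <class 'tuple'>
print(tuple(batch[1][0].shape), batch[1][1])   # (3, 224, 300) 1
```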
Here is a very simple snippet that demonstrates how to write a custom `collate_fn`:
```python
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt

# a simple custom collate function, just to show the idea
def my_collate(batch):
    # batch is a list of (image_tensor, label) tuples, one per sample
    data = [item[0] for item in batch]   # images may differ in size, so keep them in a list
    target = [item[1] for item in batch]
    target = torch.LongTensor(target)    # int labels can be batched into one tensor
    return [data, target]

def show_image_batch(img_list, title=None):
    num = len(img_list)
    fig = plt.figure()
    for i in range(num):
        ax = fig.add_subplot(1, num, i + 1)
        # convert the C*H*W tensor to the H*W*C layout that imshow expects
        ax.imshow(img_list[i].numpy().transpose([1, 2, 0]))
        if title is not None:
            ax.set_title(title[i])
    plt.show()

# do not use RandomCrop, to show that the custom collate_fn can handle images of different size;
# Resize with a single int scales the shorter side to 224 and keeps the aspect ratio
# (transforms.Scale, used in older torchvision releases, is deprecated in favor of transforms.Resize)
train_transforms = transforms.Compose([
    transforms.Resize(size=224),
    transforms.ToTensor(),
])

# change root to a valid dir on your system, see the ImageFolder documentation for more info
train_dataset = datasets.ImageFolder(root="/hd1/jdhao/toyset",
                                     transform=train_transforms)

trainset = DataLoader(dataset=train_dataset,
                      batch_size=4,
                      shuffle=True,
                      collate_fn=my_collate,  # use custom collate function here
                      pin_memory=True)

trainiter = iter(trainset)
imgs, labels = next(trainiter)  # use the built-in next(); the iterator has no .next() in Python 3
# print(type(imgs), type(labels))
show_image_batch(imgs, title=[train_dataset.classes[x] for x in labels.tolist()])
```
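The snippet above needs an `ImageFolder` directory on disk. For a quick check without real images, the following sketch swaps in a hypothetical in-memory dataset (`RandomSizeImages`, not part of PyTorch or the original code) that returns random 3-channel tensors of varying sizes, and uses the same list-based collate idea. With 5 samples, `batch_size=4` and no shuffling, the loader yields one batch of 4 images followed by a final batch of 1.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class RandomSizeImages(Dataset):
    """Hypothetical toy dataset: random 3-channel images with varying H and W."""

    def __init__(self, sizes):
        self.sizes = sizes  # list of (H, W) pairs, one per sample

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, idx):
        h, w = self.sizes[idx]
        return torch.rand(3, h, w), idx % 2  # (image, int label)

def my_collate(batch):
    data = [item[0] for item in batch]  # variable-size images stay in a list
    target = torch.tensor([item[1] for item in batch], dtype=torch.long)
    return [data, target]

dataset = RandomSizeImages([(224, 224), (224, 300), (256, 224), (224, 400), (300, 300)])
loader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=my_collate)

shapes_per_batch, labels_per_batch = [], []
for imgs, labels in loader:
    shapes_per_batch.append([tuple(img.shape) for img in imgs])
    labels_per_batch.append(labels.tolist())
    print(len(imgs), labels.tolist())
# 4 [0, 1, 0, 1]
# 1 [0]
```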