Trouble using transforms.FiveCrop()/TenCrop()

Hello,

I am trying to improve my CNN’s performance, so I decided to experiment with some transformations to see how they affect my model. I read that FiveCrop() and TenCrop() might help because they generate extra data to train on. However, when I try to train the model using one of these transformations, I get the following error:

TypeError: pic should be PIL Image or ndarray. Got <class 'tuple'>

The documentation for these transformations only includes a note about the test procedure. Any idea how to fix this?

Thanks in advance!


Could you post the code snippet where you try to apply this transformation?
Apparently you are passing a tuple, while the method expects a PIL.Image.

trans = transforms.Compose([transforms.Resize(256),
                            transforms.TenCrop(224), # this is a list of PIL Images
                            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                           ])
test_trans = transforms.Compose([transforms.Resize(256),
                            transforms.CenterCrop(224),
                            #transforms.Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])),
                            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                           ])
train = datasets.ImageFolder('/data/dog_images/train', transform = trans)
test = datasets.ImageFolder('/data/dog_images/test', transform = test_trans)
valid = datasets.ImageFolder('/data/dog_images/valid', transform = test_trans)

trainloader = torch.utils.data.DataLoader(train, shuffle = True, batch_size = 16, drop_last = True)
testloader = torch.utils.data.DataLoader(test, shuffle = False, batch_size = 16, drop_last = True)
validloader = torch.utils.data.DataLoader(valid, shuffle = False, batch_size = 16, drop_last = True)

loaders_scratch = {'train': trainloader,
                  'valid': validloader,
                  'test': testloader}

So, it appears the TenCrop part of the transform is preventing the Normalize part from working.


Please help with a way around transforms.Normalize, as it is raising an error.

I assume the error is raised since you are passing the multiple inputs returned by the TenCrop transformation.
If that’s the case, you could add the Normalize operation to the transforms.Lambda call as:

transforms.Lambda(lambda crops: torch.stack([transforms.Normalize(mean=..., std=...)(transforms.ToTensor()(crop)) for crop in crops]))

or alternatively add another transforms.Lambda as:

trans = transforms.Compose([transforms.Resize(256),
                            transforms.TenCrop(224), # this is a list of PIL Images
                            transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                            transforms.Lambda(lambda tensors:
                                torch.stack([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(t) for t in tensors]))
                           ])

Thank you @ptrblck. I am yet to try it, but I’m positive it will work. Why I didn’t think of this solution beats me.

Thank you @ptrblck, it worked.

I have another question: applying the TenCrop transform with a batch size of, say, 16 creates a batch of shape torch.Size([16, 10, 3, 224, 224]). This means I have 160 images once I reshape it to [B, C, H, W], ten of each. I think that will be too large a batch size to use, and it will also make the batch distribution unrepresentative of the entire dataset.

Please help.

If the total batch size is too large, you could either decrease the batch_size or use FiveCrop instead.
FiveCrop and TenCrop will increase the number of samples in the batch, but that is of course their purpose.
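For training, each crop keeps the label of its source image, so a common pattern (a sketch with placeholder tensors; the shapes mirror your batch, the class count of 133 is made up) is to fold the crop dimension into the batch dimension and repeat the labels. At test time you would instead average the per-crop predictions back to one prediction per image:

```python
import torch

batch = torch.randn(16, 10, 3, 224, 224)   # [B, ncrops, C, H, W] from TenCrop
labels = torch.randint(0, 133, (16,))      # one label per original image

b, ncrops, c, h, w = batch.shape
flat = batch.view(-1, c, h, w)             # [B * ncrops, C, H, W] -> 160 crops
flat_labels = labels.repeat_interleave(ncrops)  # each crop inherits its label

# at test time: average the logits over the crops of each image
logits = torch.randn(b * ncrops, 133)      # placeholder for model(flat)
avg_logits = logits.view(b, ncrops, -1).mean(dim=1)  # [B, num_classes]
```

If 160 crops per step is too much memory, lowering batch_size keeps the effective step size manageable while preserving the crop augmentation.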

I am working with ChestXray data and found a model on GitHub.
Link: GitHub - arnoweng/CheXNet: A pytorch reimplementation of CheXNet
I have downloaded the pre-trained model and loaded it onto a CUDA GPU. Now I want to test the code on an image dataset. I believe the ChestXrayDataSet class is working properly, and I am using it to generate test_dataset. Then I used PyTorch’s DataLoader to create the test_loader variable. Up to this point, everything works fine.
But when I enumerate over test_loader, it raises an error. I believe there is a problem with my lambda function in the test_dataset variable, or the problem lies in ChestXrayDataSet. I have attached the ChestXrayDataSet code and two other images showing the errors!

I don’t know exactly what’s causing the error, but try to replace the lambda operations with custom transformations implemented as classes. You could use any torchvision.transforms as the boilerplate.

PS: you can post code snippets by wrapping them into three backticks ```, which makes debugging easier :wink: