Conditional transforms for image resize

I only want to resize images that are smaller than my desired input size. Is there a way to resize images that are smaller than a certain size and leave all others unchanged?

Using the Dataset class, you can lazily load your images in the __getitem__ method.
Here is a small example:

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

class TrainDataset(Dataset):
    def __init__(self, image_paths, targets):
        self.image_paths = image_paths
        self.targets = targets

    def __getitem__(self, index):
        # Lazily load the image and its target
        image = Image.open(self.image_paths[index])
        y = self.targets[index]

        # Resize your image here, e.g. only if it is smaller
        # than the desired input size (224 is just an example)
        if image.size[0] < 224 or image.size[1] < 224:
            image = image.resize((224, 224))

        x = torch.from_numpy(np.array(image))
        return x, y

    def __len__(self):
        return len(self.image_paths)

You can wrap your Dataset in a DataLoader for efficient loading, preprocessing, shuffling, etc.
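As a rough sketch of that wrapping (the ToyDataset stand-in, batch size, and tensor shapes here are made up for illustration, not part of the answer above):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    # Minimal stand-in for a dataset whose __getitem__
    # returns a fixed-size image tensor and a target
    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        x = torch.zeros(3, 224, 224)
        y = index % 2
        return x, y

    def __len__(self):
        return self.n

# DataLoader handles batching, shuffling, and (optionally) parallel loading
loader = DataLoader(ToyDataset(8), batch_size=4, shuffle=True)
for xb, yb in loader:
    print(xb.shape)  # torch.Size([4, 3, 224, 224])
```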

Hope this helps!


For what it’s worth, the alternative to this is to subclass Resize. The advantage of this approach is that existing torchvision Dataset objects (like for CIFAR-10 or ImageNet) need not be modified - rather, a new transform can simply be used.

My implementation:

import torch
from torchvision import transforms

class ConditionalResize(transforms.Resize):
    """Resize transform, but only if the input is smaller than the resize dims."""

    @property
    def resize_height(self):
        return self.size if isinstance(self.size, int) else self.size[0]

    @property
    def resize_width(self):
        return self.size if isinstance(self.size, int) else self.size[1]

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        r_h, r_w = self.resize_height, self.resize_width
        if isinstance(img, torch.Tensor):
            h, w = img.shape[-2:]  # works for both CHW and batched NCHW input
        else:  # PIL Image
            w, h = img.size
        if w < r_w or h < r_h:
            return super().forward(img)
        return img