I only want to resize images that are smaller than my desired input size. Is there a way to resize only the images that are smaller than a certain size and leave all others untouched?
Using the Dataset class, you can lazily load your images in the __getitem__ method.
Here is a small example:
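import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

class TrainDataset(Dataset):
    def __init__(self, image_paths, targets):
        self.image_paths = image_paths
        self.targets = targets

    def __getitem__(self, index):
        # The image is only read from disk when this sample is requested
        image = Image.open(self.image_paths[index])
        y = self.targets[index]
        # Resize your image here
        if image.size...:
            ....
        .... # Do other stuff here
        # torch.from_numpy expects a NumPy array, so convert the PIL image first
        x = torch.from_numpy(np.array(image))
        return x, y

    def __len__(self):
        return len(self.image_paths)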
You can wrap your Dataset in a DataLoader for efficient loading, preprocessing, shuffling, etc.
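To make the "resize your image here" step concrete, here is a minimal sketch of one way to write the check with PIL alone. The helper name resize_if_smaller and the (64, 64) target are placeholders I made up for illustration, not something from the example above, and note that a plain resize like this stretches the image rather than preserving its aspect ratio.

```python
from PIL import Image


def resize_if_smaller(image, min_size):
    """Return `image` resized to `min_size` if it is smaller in either
    dimension; otherwise return it unchanged.

    `min_size` is (width, height), matching the ordering of PIL's `Image.size`.
    This helper is hypothetical and only illustrates the conditional check.
    """
    min_w, min_h = min_size
    w, h = image.size  # PIL reports (width, height)
    if w < min_w or h < min_h:
        return image.resize((min_w, min_h))
    return image


# Synthetic images stand in for files loaded with Image.open(...)
small = Image.new("RGB", (20, 30))
large = Image.new("RGB", (200, 300))

resized_small = resize_if_smaller(small, (64, 64))    # gets resized
untouched_large = resize_if_smaller(large, (64, 64))  # returned as-is
```

Inside __getitem__ you would call something like this right after Image.open and before converting to a tensor.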
Hope this helps!
For what it’s worth, the alternative to this is to subclass Resize. The advantage of this approach is that existing torchvision Dataset objects (like for CIFAR-10 or ImageNet) need not be modified - rather, a new transform can simply be used.
My implementation:
import torch
from torchvision import transforms

class ConditionalResize(transforms.Resize):
    """
    Resize transform but only if the input is smaller than the resize dims
    """

    @property
    def resize_height(self):
        return self.size if isinstance(self.size, int) else self.size[0]

    @property
    def resize_width(self):
        return self.size if isinstance(self.size, int) else self.size[1]

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.
        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        r_h, r_w = self.resize_height, self.resize_width
        if isinstance(img, torch.Tensor):
            # Tensors are [..., H, W]; take the last two dims so that
            # batched inputs work as well as single [C, H, W] images
            h, w = img.shape[-2:]
        else:  # PIL Image, whose size is (width, height)
            w, h = img.size
        if w < r_w or h < r_h:
            return super().forward(img)
        else:
            return img