Data Loading & Transformations: transforms.RandomApply() alternative transformation

I’m pre-processing my data to feed into a CNN and am applying RandomApply with a probability of 0.3.

Is there a way to apply a different transformation if the one in RandomApply doesn't get selected?
In other words, how do I apply another transformation in the remaining 70% of cases?
(Kind of like if-else)

transforms.Compose([transforms.RandomApply(
                        [transforms.Resize((720+50, 720+50))], p=0.3),
                    transforms.Normalize(0.5, 0.5)])

I think the cleanest approach would be either to define a custom transformation and put the if-else logic into its __call__ method, or to use e.g. transforms.Lambda and pick the transformation based on a random number.


@ptrblck Thank you for the suggestions! I ended up creating a custom transformation and added the logic in the __call__ method.

I’m getting an error in my implementation:

import numpy as np
from torchvision import transforms

class RandomStretch(object):
    """
    Stretches an image's height or width based on some probability and scale
        p_h:      Probability to stretch height
        p_w:      Probability to stretch width
        h_scale:  Tuple of (height_lower_boundary, height_upper_boundary)
        w_scale:  Tuple of (width_lower_boundary, width_upper_boundary)
    """
    def __init__(self, p_h, p_w, h_scale, w_scale):
        assert p_h + p_w < 1.0
        self.p_h = p_h
        self.p_w = p_w
        self.h_low, self.h_high = h_scale
        self.w_low, self.w_high = w_scale
    def __call__(self, sample):
        image, label = sample
        # Random float in [0, 1)
        prob = np.random.random()
        # Height stretch: prob in [0, p_h)
        if prob < self.p_h:
            h_stretch = np.random.uniform(self.h_low, self.h_high)
            image = transforms.Resize(size=(int(720*h_stretch), 720))(image)
        # Width stretch: prob in [p_h, p_h + p_w)
        elif prob < self.p_h + self.p_w:
            w_stretch = np.random.uniform(self.w_low, self.w_high)
            image = transforms.Resize(size=(720, int(720*w_stretch)))(image)
        return {'image': image, 'label': label}

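from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

train_transforms = transforms.Compose([transforms.Resize(size=(720, 720)),
                                       RandomStretch(p_h=0.25, p_w=0.25, h_scale=(0.9, 1.1), w_scale=(0.9, 1.1)),
                                       transforms.RandomAffine(degrees=0, scale=(0.9, 1.1)),
                                       transforms.ToTensor(),  # Normalize expects a tensor, not a PIL.Image
                                       transforms.Normalize(0.5, 0.5)])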

train_data = ImageFolder(train_dir, transform=train_transforms, is_valid_file=check_valid)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, pin_memory=True, num_workers=4)

train_images, train_labels = next(iter(train_loader))
TypeError                                 Traceback (most recent call last)
<ipython-input-95-eda6aa29a3f9> in <module>
      1 train_classes = dict(zip((0, 1, 2, 3, 4), train_data.classes))
----> 2 train_images, train_labels = next(iter(train_loader))
      4 print('[Train Loader]\n')
      5 print('images.shape: {} \ttype(images): {}'.format(train_images.shape, type(train_images)))

~/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/ in __next__(self)
    362     def __next__(self):
--> 363         data = self._next_data()
    364         self._num_yielded += 1
    365         if self._dataset_kind == _DatasetKind.Iterable and \

~/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/ in _next_data(self)
    987             else:
    988                 del self._task_info[idx]
--> 989                 return self._process_data(data)
    991     def _try_put_index(self):

~/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/ in _process_data(self, data)
   1012         self._try_put_index()
   1013         if isinstance(data, ExceptionWrapper):
-> 1014             data.reraise()
   1015         return data

~/anaconda3/envs/cv/lib/python3.6/site-packages/torch/ in reraise(self)
    393             # (, so we work around it.
    394             msg = KeyErrorMessage(msg)
--> 395         raise self.exc_type(msg)

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/_utils/", line 185, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/_utils/", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/_utils/", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torchvision/datasets/", line 139, in __getitem__
    sample = self.transform(sample)
  File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torchvision/transforms/", line 61, in __call__
    img = t(img)
  File "<ipython-input-93-0b840063ddb7>", line 39, in __call__
    image, label = sample
TypeError: 'Image' object is not iterable

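Most torchvision.transforms take a single image (e.g. a PIL.Image) as input and return the transformed image. ImageFolder calls transform with the image only (the label goes through target_transform), so your custom transformation fails when it tries to unpack the incoming image into (image, label), and it would also return a dict instead of an image. That's why you won't be able to use this transformation directly in transforms.Compose.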
Alternatively, you could apply the torchvision transforms to the image and target separately and use RandomStretch on both of them inside another custom class.