I’m pre-processing my data to feed into a CNN and am applying `RandomApply` with a probability of 0.3.

Is there a way to apply a different transformation when the one inside `RandomApply` is *not* selected?
In other words, how do I apply another transformation in the remaining `70%` of cases?
(Kind of like an if/else.)

``````
transforms.Compose([
    transforms.RandomApply([transforms.Resize((720 + 50, 720 + 50))], p=0.3),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])
``````

I think the cleanest approach would be to either define a custom transformation and add the logic to its `__call__` method, or to use e.g. `transforms.Lambda` and pick the transformation based on a random number.
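For example, the `transforms.Lambda` route could look like this. This is a minimal sketch: the helper name `random_apply_else` is my own, and the `CenterCrop` else-branch is a hypothetical placeholder, not something from the thread.

``````
import random


def random_apply_else(t_main, t_else, p=0.3):
    """If/else version of RandomApply: apply t_main with probability p,
    otherwise apply t_else. Both arguments are any image callables,
    e.g. torchvision transforms."""
    def _apply(img):
        if random.random() < p:
            return t_main(img)   # the RandomApply branch (probability p)
        return t_else(img)       # the "else" branch (probability 1 - p)
    return _apply


# Usage sketch with the pipeline from the question (requires torchvision):
# transforms.Compose([
#     transforms.Lambda(random_apply_else(
#         transforms.Resize((720 + 50, 720 + 50)),
#         transforms.CenterCrop(720),  # hypothetical else-branch
#         p=0.3)),
#     transforms.ToTensor(),
#     transforms.Normalize(0.5, 0.5)])
``````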


@ptrblck Thank you for the suggestions! I ended up creating a custom transformation and adding the logic to the `__call__` method.

I’m getting an error in my implementation:

``````
import numpy as np
from torchvision import transforms
from torchvision.datasets import ImageFolder


class RandomStretch(object):
    '''
    Stretches an image's height or width based on some probability and scale.

    Params:
        p_h:      Probability to stretch height
        p_w:      Probability to stretch width
        h_scale:  Tuple of (height_lower_boundary, height_upper_boundary)
        w_scale:  Tuple of (width_lower_boundary, width_upper_boundary)
    '''

    def __init__(self, p_h, p_w, h_scale, w_scale):
        assert p_h + p_w < 1.0
        self.p_h = p_h
        self.p_w = p_w
        self.h_low, self.h_high = h_scale
        self.w_low, self.w_high = w_scale

    def __call__(self, sample):
        image, label = sample

        # Random float in [0, 1)
        prob = np.random.random()

        # Height stretch (first p_h slice of the probability mass)
        if prob < self.p_h:
            h_stretch = np.random.uniform(self.h_low, self.h_high)
            image = transforms.Resize(size=(int(720 * h_stretch), 720))(image)

        # Width stretch (next p_w slice; Resize needs int sizes)
        elif prob < self.p_h + self.p_w:
            w_stretch = np.random.uniform(self.w_low, self.w_high)
            image = transforms.Resize(size=(720, int(720 * w_stretch)))(image)

        return {'image': image, 'label': label}


train_transforms = transforms.Compose([
    transforms.Resize(size=(720, 720)),
    RandomStretch(p_h=0.25, p_w=0.25, h_scale=(0.9, 1.1), w_scale=(0.9, 1.1)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomAffine(degrees=0, scale=(0.9, 1.1)),
    transforms.Grayscale(1),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),
])

train_data = ImageFolder(train_dir, transform=train_transforms, is_valid_file=check_valid)
``````
``````
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-95-eda6aa29a3f9> in <module>
1 train_classes = dict(zip((0, 1, 2, 3, 4), train_data.classes))
----> 2 train_images, train_labels = next(iter(train_loader))
3
5 print('images.shape: {} \ttype(images): {}'.format(train_images.shape, type(train_images)))

361
362     def __next__(self):
--> 363         data = self._next_data()
364         self._num_yielded += 1
365         if self._dataset_kind == _DatasetKind.Iterable and \

987             else:
--> 989                 return self._process_data(data)
990
991     def _try_put_index(self):

1012         self._try_put_index()
1013         if isinstance(data, ExceptionWrapper):
-> 1014             data.reraise()
1015         return data
1016

~/anaconda3/envs/cv/lib/python3.6/site-packages/torch/_utils.py in reraise(self)
393             # (https://bugs.python.org/issue2651), so we work around it.
394             msg = KeyErrorMessage(msg)
--> 395         raise self.exc_type(msg)

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
data = fetcher.fetch(index)
File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torchvision/datasets/folder.py", line 139, in __getitem__
sample = self.transform(sample)
File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
img = t(img)
File "<ipython-input-93-0b840063ddb7>", line 39, in __call__
image, label = sample
TypeError: 'Image' object is not iterable
``````

`torchvision.transforms` expect a `PIL.Image` as input and also return an image, while your custom transformation tries to unpack the incoming image into `(image, label)` and returns a `dict`, so you won't be able to use this transformation directly in `transforms.Compose`.
You could apply the `torchvision.transforms` to the image and target separately, and use `RandomStretch` on both of them in another custom class.
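A minimal sketch of that "another custom class" idea, assuming samples arrive as `{'image': ..., 'label': ...}` dicts as in the `RandomStretch` example (the class name `PairedTransform` is my own, not from the thread):

``````
class PairedTransform:
    """Wrap an image-only transform (e.g. a torchvision transform or a
    whole Compose pipeline) so it works on dict samples, applying the
    transform to the image and passing the label through untouched."""

    def __init__(self, image_transform):
        self.image_transform = image_transform

    def __call__(self, sample):
        return {'image': self.image_transform(sample['image']),
                'label': sample['label']}
``````

A joint transform like `RandomStretch` can then be called on the whole sample dict, while purely image-level transforms are wrapped in `PairedTransform`, keeping the label aligned through the pipeline.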