I’m getting an error in my implementation:
class RandomStretch(object):
'''
Stretches an image's height or width based on some probability and scale
Params:
p_h: Probability to stretch height
p_w: Probability to stretch width
h_scale: Tuple of (height_lower_boundary, height_upper_boundary)
w_scale: Tuple of (width_lower_boundary, width_upper_boundary)
'''
def __init__(self, p_h, p_w, h_scale, w_scale):
assert p_h + p_w < 1.0
self.p_h = p_h
self.p_w = p_w
self.h_low, self.h_high = h_scale
self.w_low, self.w_high = w_scale
def __call__(self, sample):
image, label = sample
# Random float in [0, 1)
prob = np.random.random()
# Height stretch
if 0 < prob and prob < self.p_h:
h_stretch = np.random.uniform(self.h_low, self.h_high)
image = transforms.Resize(size=(720*h_stretch, 720))(image)
# Width stretch
elif self.p_h < prob and prob < self.p_w:
w_stretch = np.random.uniform(self.w_low, self.w_high)
image = transforms.Resize(size=(720, 720*w_stretch))(image)
return {'image': image, 'label:': label}
train_transforms = transforms.Compose([transforms.Resize(size=(720, 720)),
RandomStretch(p_h=0.25, p_w=0.25, h_scale=(0.9, 1.1), w_scale=(0.9, 1.1)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomAffine(degrees=0, scale=(0.9, 1.1)),
transforms.Grayscale(1),
transforms.ToTensor(),
transforms.Normalize(0.5, 0.5)])
train_data = ImageFolder(train_dir, transform=train_transforms, is_valid_file=check_valid)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True, pin_memory=True, num_workers=4)
train_images, train_labels = next(iter(train_loader))
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-95-eda6aa29a3f9> in <module>
1 train_classes = dict(zip((0, 1, 2, 3, 4), train_data.classes))
----> 2 train_images, train_labels = next(iter(train_loader))
3
4 print('[Train Loader]\n')
5 print('images.shape: {} \ttype(images): {}'.format(train_images.shape, type(train_images)))
~/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/dataloader.py in __next__(self)
361
362 def __next__(self):
--> 363 data = self._next_data()
364 self._num_yielded += 1
365 if self._dataset_kind == _DatasetKind.Iterable and \
~/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _next_data(self)
987 else:
988 del self._task_info[idx]
--> 989 return self._process_data(data)
990
991 def _try_put_index(self):
~/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1012 self._try_put_index()
1013 if isinstance(data, ExceptionWrapper):
-> 1014 data.reraise()
1015 return data
1016
~/anaconda3/envs/cv/lib/python3.6/site-packages/torch/_utils.py in reraise(self)
393 # (https://bugs.python.org/issue2651), so we work around it.
394 msg = KeyErrorMessage(msg)
--> 395 raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
data = fetcher.fetch(index)
File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torchvision/datasets/folder.py", line 139, in __getitem__
sample = self.transform(sample)
File "/home/wilson/anaconda3/envs/cv/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
img = t(img)
File "<ipython-input-93-0b840063ddb7>", line 39, in __call__
image, label = sample
TypeError: 'Image' object is not iterable