Hi,
I’m using transforms.RandomHorizontalFlip for data augmentation. However it throws the following error.
File "/home/usr/Courses/Projects/Iceberg_id/load_data.py", line 129, in __getitem__
imgs = self.transform(imgs)
File "/home/usr/anaconda3/lib/python3.6/site-packages/torchvision-0.1.9-py3.6.egg/torchvision/transforms.py", line 34, in __call__
img = t(img)
File "/home/usr/anaconda3/lib/python3.6/site-packages/torchvision-0.1.9-py3.6.egg/torchvision/transforms.py", line 326, in __call__
return img.transpose(Image.FLIP_LEFT_RIGHT)
TypeError: transpose received an invalid combination of arguments - got (int), but expected (int dim0, int dim1)
My Transform is is follows:
txfm = transforms.Compose([Resize((80,80)), ToTensor(),\
transforms.RandomHorizontalFlip(),\
transforms.Normalize(mn,std)])
Here Resize and ToTensor were defined by me, the other two are from transforms library.
These transforms are implemented as follows:
class req_data(Dataset):
'''
Creates the dataset for the data
initiates the __len__ and __getitem__
'''
def __init__(self,images,labels, transform = None):
self.images = images
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.labels)
def __getitem__(self,indx):
imgs = self.images[indx]
labls = self.labels[indx]
if self.transform:
imgs = self.transform(imgs)
return (imgs,labls)
# Create the transforms
class Resize(object):
'''
Inputs: a tuple of the form (h,w)
Returns: a resized image of the given dimensions
'''
def __init__(self,output_size):
self.output_size = output_size
def __call__(self,inpt):
image = inpt
new_h, new_w = self.output_size
image = image.transpose((1,2,0))
image = transform.resize(image,self.output_size)
image = image.transpose((2,0,1))
return image
class ToTensor(object):
'''Convert ndarrays in sample to Tensors
We don't need to make a tensor out of labels
as it is only 1 dim array
'''
def __call__(self, inpt):
image = inpt
return torch.from_numpy(image).float()
The transform in transform.resize(image,self.output_size) is from skimage, not PyTorch.
My PyTorch is installed from source.
Could someone please help me out?