I would like to apply a transformation to a dataset composed of images.
The dataset is composed of N images of size C x H x W, where C = 3 and H = W = 256. The images are stored in a torch.Tensor.
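For reference, here is a minimal sketch of the layout described above. It assumes N = 4 (an arbitrary value) and uses random values in place of real images. Indexing the first dimension of the N x C x H x W tensor returns a single C x H x W image.

```python
import torch

# C, H, W come from the question; N = 4 is an arbitrary example value.
N, C, H, W = 4, 3, 256, 256

# Random values stand in for real pixel data. torch.Tensor(N, C, H, W) would
# allocate the same shape but leave the values uninitialized.
data = torch.rand(N, C, H, W)

first_image = data[0]      # indexing the first dimension yields one C x H x W image
num_images = data.size(0)  # N, the number of images
```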
I would like to apply certain transformations to each image, e.g. transforms.CenterCrop().
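Center cropping a single C x H x W tensor comes down to slicing the last two dimensions around the middle. The sketch below uses a hypothetical helper, center_crop_tensor, which is not part of torchvision. Its offset rounding, (H - size) // 2, may differ by one pixel from torchvision's CenterCrop when H - size is odd.

```python
import torch


def center_crop_tensor(img, size):
    """Hypothetical helper: crop the central size x size region of a tensor.

    Only the last two dimensions are sliced, so no conversion to PIL is needed.
    """
    h, w = img.shape[-2], img.shape[-1]
    if size > h or size > w:
        raise ValueError("crop size larger than image")
    top = (h - size) // 2   # rows skipped above the crop
    left = (w - size) // 2  # columns skipped left of the crop
    return img[..., top:top + size, left:left + size]


img = torch.arange(3 * 256 * 256, dtype=torch.float32).reshape(3, 256, 256)
cropped = center_crop_tensor(img, 100)  # rows and columns 78..177
```

Because the slice starts with `...`, the same helper also crops a whole N x C x H x W batch in one call.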
My class is modeled on the TensorDataset() class found here: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataset.py
See the code below:
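import torch
from torch.utils.data import Dataset
from torchvision import transforms

N, C, H, W = 10, 3, 256, 256  # N is arbitrary in this example
data = torch.Tensor(N, C, H, W)

class TensorDataset(Dataset):
    def __init__(self, tensor):
        self.tensor = tensor
        self.center_crop = transforms.CenterCrop(100)

    def __getitem__(self, index):
        out = self.tensor[index]      # one C x H x W image
        out = self.center_crop(out)   # this is the line that fails
        return out

    def __len__(self):
        return self.tensor.size(0)    # number of images N (tensor[0].size(0) would give C)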
When I run the following code:
d = TensorDataset(data)
# print the shape of the first image; I expect 3, 100, 100
print(d[0].shape)
I get the following error message:
Traceback (most recent call last):
  File "/Users/jphreid/Documents/GitHub/Amazon/amazon.py", line 75, in <module>
    print(d[0].shape)
  File "/Users/jphreid/Documents/GitHub/Amazon/amazon.py", line 58, in __getitem__
    out = self.center_crop(out)
  File "/anaconda3/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 215, in __call__
    return F.center_crop(img, self.size)
  File "/anaconda3/lib/python3.6/site-packages/torchvision/transforms/functional.py", line 305, in center_crop
    w, h = img.size
TypeError: 'builtin_function_or_method' object is not iterable
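The last frame of the traceback suggests why this fails. In the torchvision version shown, functional.center_crop reads w, h = img.size, which follows the PIL Image convention where size is a (width, height) tuple. On a torch.Tensor, size is a method, so unpacking it raises a TypeError (the exact message text depends on the Python version). A minimal reproduction using only torch:

```python
import torch

img = torch.zeros(3, 256, 256)

# On a PIL image, `size` is a (width, height) tuple, so `w, h = img.size`
# unpacks two numbers. On a torch.Tensor, `size` is a method, so the same
# unpacking tries to iterate over a method object and raises TypeError.
try:
    w, h = img.size
    error = None
except TypeError as exc:
    error = exc

# Calling the method returns the dimensions instead: torch.Size([3, 256, 256]).
dims = img.size()
```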
The code above works fine if I remove out = self.center_crop(out) from the __getitem__() method.
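A minimal workaround sketch, assuming a center crop is all that is needed and that plain tensor slicing is an acceptable substitute for torchvision's transform. The class name CenterCropTensorDataset is hypothetical. Another commonly used route is to wrap the crop as transforms.Compose([transforms.ToPILImage(), transforms.CenterCrop(100), transforms.ToTensor()]), which converts each image to PIL and back. As far as I know, newer torchvision releases (0.8 onward) also accept tensors in CenterCrop directly.

```python
import torch
from torch.utils.data import Dataset


class CenterCropTensorDataset(Dataset):
    """Hypothetical Dataset that center-crops C x H x W tensor images by slicing."""

    def __init__(self, tensor, size):
        self.tensor = tensor  # shape N x C x H x W
        self.size = size

    def __getitem__(self, index):
        img = self.tensor[index]  # one C x H x W image
        h, w = img.shape[-2:]
        top = (h - self.size) // 2
        left = (w - self.size) // 2
        return img[:, top:top + self.size, left:left + self.size]

    def __len__(self):
        return self.tensor.size(0)  # N, not C


data = torch.rand(4, 3, 256, 256)
d = CenterCropTensorDataset(data, 100)
```

Slicing avoids the PIL round trip, which would otherwise rescale float values to 8-bit and back.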
Help would be appreciated.