torch.Tensor to Dataset - torchvision.transforms

I would like to apply a transformation to a dataset composed of images.

The dataset is composed of N images of size C x H x W, where C = 3, H = W = 256. The images are stored in a torch.Tensor.

I would like to apply certain transformations to each image, e.g. transforms.CenterCrop().

I was inspired by the TensorDataset class found here: https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataset.py

See the code below.

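import torch
from torch.utils.data import Dataset
from torchvision import transforms

N, C, H, W = 10, 3, 256, 256     # example sizes (N = 10 is arbitrary)
data = torch.Tensor(N, C, H, W)  # uninitialized N x C x H x W tensor

class TensorDataset(Dataset):
    def __init__(self, tensor):
        self.tensor = tensor
        self.center_crop = transforms.CenterCrop(100)

    def __getitem__(self, index):
        out = self.tensor[index]      # one C x H x W image
        out = self.center_crop(out)   # this line raises the error below
        return out

    def __len__(self):
        # number of images N (first dimension), not C
        return self.tensor.size(0)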

When I run the following code:

d = TensorDataset(data)
# print the first image shape; I should get 3, 100, 100
print(d[0].shape)

I get the following error message:

Traceback (most recent call last):
  File "/Users/jphreid/Documents/GitHub/Amazon/amazon.py", line 75, in <module>
    print(d[0].shape)
  File "/Users/jphreid/Documents/GitHub/Amazon/amazon.py", line 58, in __getitem__
    out = self.center_crop(out)
  File "/anaconda3/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 215, in __call__
    return F.center_crop(img, self.size)
  File "/anaconda3/lib/python3.6/site-packages/torchvision/transforms/functional.py", line 305, in center_crop
    w, h = img.size
TypeError: 'builtin_function_or_method' object is not iterable

The code above works fine if I remove out = self.center_crop(out) from the __getitem__() method.
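The traceback points at w, h = img.size inside F.center_crop. A quick check, assuming any reasonably recent PyTorch, shows why that line cannot work on a tensor: on a PIL Image .size is a (width, height) tuple, but on a torch.Tensor it is a method, which is exactly the 'builtin_function_or_method' object named in the error.

```python
import torch

img = torch.rand(3, 256, 256)  # one C x H x W image

# On a tensor, .size is a method; calling it returns a torch.Size
print(type(img.size).__name__)  # builtin_function_or_method
print(img.size())               # torch.Size([3, 256, 256])

# The older F.center_crop does `w, h = img.size`, which expects a tuple
try:
    w, h = img.size
    unpacked = True
except TypeError as err:
    unpacked = False
    print("TypeError:", err)
```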

Help would be appreciated.

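The torchvision transformations work on PIL.Images. In the version shown in your traceback, F.center_crop reads w, h = img.size, which is a (width, height) tuple on a PIL Image but a method on a tensor, hence the TypeError.
You could therefore store or load the images as PIL Images in your Dataset and transform them to tensors after the cropping. Alternatively, if you already have the tensors, you could transform each one back to an image, apply the transformation, and transform it back to a tensor.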

import torchvision.transforms.functional as TF

...
def __getitem__(self, index):
    out = self.tensor[index]        # C x H x W tensor
    out = TF.to_pil_image(out)      # tensor -> PIL Image
    out = self.center_crop(out)     # crop the PIL Image
    out = TF.to_tensor(out)         # PIL Image -> tensor
    return out
...

Thanks - very much appreciated!