Follow up question for Raw images using DatasetFolder
If I use transforms.ToTensor(), I get the following error:
img should be PIL Image. Got <class ‘torch.Tensor’>
If I used transforms.ToPILImage() I get the following error:
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class ‘PIL.Image.Image’>
If I change the return type of __raw_loader to PIL image:
TypeError: pic should be Tensor or ndarray. Got <class ‘PIL.Image.Image’>.
Kind of going in circles here. What is the proper way to load a raw image using DatasetFolder?
Code:
import torch
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.datasets import DatasetFolder
class RawFolder(DatasetFolder):
EXTENSIONS = ('.raw')
def __init__(self, root, transform=None, target_transform=None, loader=None):
super(RawFolder, self).__init__(root, self.__raw_loader,
self.EXTENSIONS,
transform=transform,
target_transform=target_transform)
@staticmethod
def __raw_loader(filename):
data = np.fromfile(filename, dtype='float32').reshape(80,180)
return Image.fromarray(data)
def __check_file(filename):
return os.path.isfile(filename) == True
data_transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Grayscale(num_output_channels=1)
])
the_dataset = RawFolder(root=rootdir, transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(the_dataset,
batch_size=64,
shuffle=False,
num_workers=0)
classes = ['1']
import numpy as np
# files already created, training and testing can start here.
# re-run first cell if necessary
def goshow(img):
img = img / 2 + 0.5 # unnormalize
plt.imshow(np.transpose(img, (1, 2, 0))) # convert from Tensor image
# Visualize a few images
dataiter = iter(dataset_loader)
images, labels = dataiter.next()
images = images.numpy() # convert images to numpy for display
# plot the images in the batch, along with the corresponding labels
fig = plt.figure(figsize=(25, 12))
# display some images
for idx in np.arange(25):
ax = fig.add_subplot(5, 5, idx+1, xticks=[], yticks=[])
goshow(images[idx])
ax.set_title(classes[labels[idx]])
Error for the last case mentioned above:
Errors always occur here: images, labels = dataiter.next()
TypeError Traceback (most recent call last)
<ipython-input-148-ffafba196b64> in <module>
55 # Visualize a few images
56 dataiter = iter(dataset_loader)
---> 57 images, labels = dataiter.next()
58 images = images.numpy() # convert images to numpy for display
59
~\anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
344 def __next__(self):
345 index = self._next_index() # may raise StopIteration
--> 346 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
347 if self._pin_memory:
348 data = _utils.pin_memory.pin_memory(data)
~\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
~\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
~\anaconda3\lib\site-packages\torchvision\datasets\folder.py in __getitem__(self, index)
138 sample = self.loader(path)
139 if self.transform is not None:
--> 140 sample = self.transform(sample)
141 if self.target_transform is not None:
142 target = self.target_transform(target)
~\anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, img)
68 def __call__(self, img):
69 for t in self.transforms:
---> 70 img = t(img)
71 return img
72
~\anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, pic)
134
135 """
--> 136 return F.to_pil_image(pic, self.mode)
137
138 def __repr__(self):
~\anaconda3\lib\site-packages\torchvision\transforms\functional.py in to_pil_image(pic, mode)
117 """
118 if not(isinstance(pic, torch.Tensor) or isinstance(pic, np.ndarray)):
--> 119 raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic)))
120
121 elif isinstance(pic, torch.Tensor):
TypeError: pic should be Tensor or ndarray. Got <class 'PIL.Image.Image'>.