My transforms fn:
def data_transforms(phase):
if phase == TRAIN:
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.RandomGrayscale(p=0.1),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
if phase == VAL:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
if phase == TEST:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
Dataset and Dataloaders fn:
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms(x))
for x in [TRAIN, VAL, TEST]}
dataloaders = {TRAIN: torch.utils.data.DataLoader(image_datasets[TRAIN], batch_size = 3, shuffle=True),
VAL: torch.utils.data.DataLoader(image_datasets[VAL], batch_size = 1, shuffle=False),
TEST: torch.utils.data.DataLoader(image_datasets[TEST], batch_size = 1, shuffle=False)}
but when I am trying to visualize it’s throwing a TypeError:
def imshow(inp, title=None):
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001)
inputs, classes = next(iter(dataloaders[TRAIN])) # Error in this line
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])
Error
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>