[1]
import torch
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data
import torchvision.datasets as datasets
import torch.nn as nn
# [2]
# Shared preprocessing for every dataset split below: convert each image to a
# tensor, then normalize each of the three channels as (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# [3]
# Load the CIFAR-10 training and test splits from ./data, then hold out 20% of
# the training split as a validation set.
train_data = datasets.CIFAR10(root="./data",
                              train=True,
                              download=True,
                              transform=transform)
# Derive the validation size as the remainder so the two lengths always sum to
# len(train_data). Computing both as int(len * fraction) can come up one short
# (e.g. for a length of 7: 5 + 1 = 6), which makes random_split raise.
n_train = int(len(train_data) * 0.8)
n_val = len(train_data) - n_train
train_data, val_data = torch.utils.data.random_split(train_data, [n_train, n_val])
test_data = datasets.CIFAR10(root="./data",
                             train=False,
                             download=True,
                             transform=transform)
# [4]
# Map each integer label index to its class name, as listed by the dataset.
classes = test_data.classes
dic_classes = dict(enumerate(classes))
print(dic_classes)
# [5]
# Wrap each split in a DataLoader producing mini-batches of 16 samples.
# The training and validation loaders shuffle; the test loader keeps dataset order.
trainloader, valloader, testloader = [
    torch.utils.data.DataLoader(split, batch_size=16, shuffle=reshuffle)
    for split, reshuffle in ((train_data, True), (val_data, True), (test_data, False))
]
# [6]
import matplotlib.pyplot as plt
import numpy as np


def imshow(img, labels, dic):
    """Plot a batch of images in a grid, titling each with its class name.

    Args:
        img: Batch of image tensors indexed as ``img[i]``. Each image is
            un-normalized, converted with ``.numpy()`` and transposed with
            axes (1, 2, 0) for display — presumably (C, H, W) -> (H, W, C)
            as produced by ``ToTensor``.
        labels: Sequence of label tensors; ``labels[i].item()`` is used as the
            key into ``dic``. One subplot is drawn per label.
        dic: Mapping from integer label to the title shown above each image.
    """
    num = len(labels)
    # Enough columns/rows to show every image; a floor(sqrt(num)) square grid
    # silently dropped images whenever num was not a perfect square.
    cols = max(1, int(np.ceil(np.sqrt(num))))
    rows = -(-num // cols)  # ceiling division
    fig = plt.figure(figsize=(20, 20))
    for i in range(num):
        ax = fig.add_subplot(rows, cols, i + 1)
        # Undo Normalize(mean=0.5, std=0.5): x / 2 + 0.5 maps [-1, 1] back to [0, 1].
        npimg = (img[i] / 2 + 0.5).numpy()
        ax.imshow(np.transpose(npimg, (1, 2, 0)))
        ax.title.set_text(dic[labels[i].item()])
        ax.axis('off')
    # Show once, after all subplots are drawn (it was commented out inside the loop).
    plt.show()
# Fetch one random mini-batch of training images (trainloader shuffles).
dataiter = iter(trainloader)
# Use the built-in next(): current PyTorch DataLoader iterators implement only
# __next__, so calling dataiter.next() raises AttributeError.
images, labels = next(dataiter)
# Show the images.
imshow(images, labels, dic_classes)
# Print the ground-truth labels.
print(' '.join(dic_classes[label.item()] for label in labels))
# Cell [6] is the code that raises the error. I hope this helps.
# Cells [1]–[5] all ran normally; the error occurs only in cell [6].