I am getting this error while visualizing the data
This is how I have loaded the data.
num_classes = 2
batch_size = 13
net=CNN(out_1=13, out_2=32)
DATA_PATH_TRAIN ="F:/project/d2"
DATA_PATH_TEST ="F:/project/d2"
trans = transforms.Compose([
transforms.ToTensor(),
#transforms.Resize((100, 100))
])
train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
validation_dataset=train_dataset
validation_loader = DataLoader(dataset=validation_dataset, batch_size=batch_size, shuffle=True)
criterion = nn.CrossEntropyLoss()
learning_rate = 0.001
optimizer = torch.optim.Adam(net.parameters(), lr = learning_rate)
Function for Visualizing
class_names = ["not a face","face"]
def visualize_model(net, num_images=6):
was_training = net.training
net.eval()
images_so_far = 0
fig = plt.figure()
with torch.no_grad():
for i, (inputs, labels) in enumerate(validation_loader):
outputs = net(inputs)
_, preds = torch.max(outputs, 1)
for j in range(inputs.size()[0]):
images_so_far += 1
ax = plt.subplot(num_images//2, 2, images_so_far)
ax.axis('off')
ax.set_title('predicted: {}'.format(class_names[preds[j]]))
imshow(inputs.data[j])
if images_so_far == num_images:
net.train(mode=was_training)
return
net.train(mode=was_training)
my function for imshow
def imshow(img):
#img = img / 2 + 0.5
plt.imshow(np.transpose(img[0].numpy(), (1, 2, 0)))
plt.show()
def main():
dataiter = iter(train_loader)
images, labels = dataiter.next()
imshow(images)
print(labels)
if __name__ == "__main__":
main()
I am getting proper output for this imshow after loading my data
as shown in the image below…
I am also getting label for that particular image…
ptrblck
February 16, 2019, 9:56am
2
I’m not sure what the question is, as your code seems to be running fine.
This is the error I am getting when I run
visualize_model(net)
Error message
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-43-7ec848dbaa28> in <module>()
1
----> 2 visualize_model(net)
<ipython-input-41-3b3aaf3229b1> in visualize_model(net, num_images)
18 ax.axis('off')
19 ax.set_title('predicted: {}'.format(class_names[preds[j]]))
---> 20 imshow(inputs.data[j])
21
22 if images_so_far == num_images:
<ipython-input-40-315a2e968437> in imshow(img)
1 def imshow(img):
2 #img = img / 2 + 0.5
----> 3 plt.imshow(np.transpose(img[0].numpy(), (1, 2, 0)))
4 plt.show()
5
C:\ProgramData\Anaconda3\lib\site-packages\numpy\core\fromnumeric.py in transpose(a, axes)
596
597 """
--> 598 return _wrapfunc(a, 'transpose', axes)
599
600
C:\ProgramData\Anaconda3\lib\site-packages\numpy\core\fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
49 def _wrapfunc(obj, method, *args, **kwds):
50 try:
---> 51 return getattr(obj, method)(*args, **kwds)
52
53 # An AttributeError occurs if the object does not have
ValueError: axes don't match array
ptrblck
February 16, 2019, 10:07am
4
Thanks for the error message.
It looks like you are indexing your images
twice at dim0
.
In your eval loop, you are calling imshow(inputs.data[j])
, which gets the jth image in the current batch, and then again in imshow
you are calling img[0]
.
It looks like the indexing in imshow
could be removed.
Could you check that and see, if it’s running?
If I remove the indexing in the imshow in defintion then I still get the same error. And if I remove indexing in the visualize_model function then I get the following output…
ptrblck
February 16, 2019, 10:33am
6
Did you run the imshow
code again?
This code works for me:
def imshow(img):
plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
plt.show()
dataset = datasets.FakeData(transform=transforms.ToTensor())
image, target = dataset[0]
print(image.shape)
imshow(image)
loader = DataLoader(
dataset,
batch_size=2
)
inputs, _ = next(iter(loader))
for j in range(inputs.size()[0]):
print(inputs.shape)
imshow(inputs.data[j])
1 Like
On running this function as imshow this is working perfectlyy… thank you…
can you explain what does the following line do
plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))