Hello everyone.
Is it possible to generate the results of the trained network, i.e. show the classified images after training in PyTorch?
Any help is much appreciated.
Many thanks
Would you like to classify and visualize unseen test images using your trained model?
If so, the Training a Classifier tutorial might give you some ideas for writing the code. The tutorial also evaluates test images, visualizes them, and prints the corresponding predicted classes.
Let me know if I misunderstood your use case.
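As a rough sketch of the evaluation step from that tutorial, the snippet below runs a batch through a classifier and prints the predicted class names. The tiny linear `model` and the random `images` batch are hypothetical stand-ins for your trained network and a batch from your test loader; the class names are the CIFAR10 ones.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: a tiny untrained classifier and a random batch
# shaped like CIFAR10 images ([batch_size, 3, 32, 32]).
classes = ("plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, len(classes)))
images = torch.rand(4, 3, 32, 32)

model.eval()  # switch layers such as dropout/batchnorm to eval behavior
with torch.no_grad():  # no gradients are needed for inference
    outputs = model(images)            # [batch_size, nb_classes] logits
    predicted = outputs.argmax(dim=1)  # [batch_size] predicted class indices

pred_names = [classes[i] for i in predicted.tolist()]
print("Predicted:", " ".join(f"{name:5s}" for name in pred_names))
```

With your own network, replace `model` and `images` with your trained model and a batch from your test set.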
Thanks a lot, ptrblck. Yes, what I mean is that after training the network I would like to visualize the result of the classification (validation). The images look like this:
It's a kind of pixel-wise classification, so I want to reproduce these results from a paper:
It looks like you are dealing with a segmentation task.
So I assume your output would have the shape [batch_size, nb_classes, height, width], is that correct?
If so, you can get the predicted class index using pred = torch.argmax(output, 1).
Using the pred tensor, you should be able to create these color images using a mapping from class index to color code.
If you don’t have the mapping yet, I would suggest creating a dict with some random mapping to your desired colors.
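A minimal sketch of such a mapping, assuming 4 classes and arbitrary colors: the dict is turned into a lookup tensor (`palette`), so indexing it with the `pred` tensor converts every class index to an RGB color in one step. `class_to_color` and `colorize` are hypothetical names, not part of PyTorch.

```python
import torch

# Hypothetical mapping from class index to an RGB color (0-255).
# Replace these entries with the colors used in your reference figure.
class_to_color = {
    0: (0, 0, 0),      # e.g. background / unlabeled
    1: (255, 0, 0),
    2: (0, 255, 0),
    3: (0, 0, 255),
}

def colorize(pred, class_to_color):
    """Map a [height, width] tensor of class indices to a [height, width, 3] uint8 RGB image."""
    nb_classes = max(class_to_color) + 1
    # Lookup table: row i holds the color of class i (unlisted indices stay black).
    palette = torch.zeros(nb_classes, 3, dtype=torch.uint8)
    for idx, color in class_to_color.items():
        palette[idx] = torch.tensor(color, dtype=torch.uint8)
    # Advanced indexing replaces each class index with its palette row.
    return palette[pred]

# Dummy model output with the assumed shape [batch_size, nb_classes, height, width].
output = torch.randn(2, len(class_to_color), 24, 24)
pred = torch.argmax(output, 1)            # [batch_size, height, width]
rgb = colorize(pred[0], class_to_color)   # [24, 24, 3]
```

The result can be shown with `plt.imshow(rgb.numpy())`. Using a lookup table instead of looping over pixels keeps the mapping vectorized; it assumes the class indices are non-negative integers.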
Indeed, it is a segmentation task. I do have the ground truth; is it possible to visualize the output then?
Yes, as I said, you can get the predictions using torch.argmax and visualize them directly or via a color mapping scheme:
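import torch
import matplotlib.pyplot as plt
output = torch.randn(1, 10, 24, 24)  # dummy output: [batch_size, nb_classes, height, width]
pred = torch.argmax(output, 1)  # predicted class index per pixel: [batch_size, height, width]
plt.imshow(pred[0].numpy())
plt.show()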
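Since you have the ground truth, one option (a sketch, assuming `target` holds class indices with shape [batch_size, height, width]) is to place the prediction next to the ground truth and compute a mask of misclassified pixels. The random `output` and `target` below are dummy data.

```python
import torch

torch.manual_seed(0)
# Dummy tensors with the assumed shapes:
# output: [batch_size, nb_classes, height, width], target: [batch_size, height, width]
output = torch.randn(1, 10, 24, 24)
target = torch.randint(0, 10, (1, 24, 24))

pred = torch.argmax(output, 1)  # [batch_size, height, width]

# Place prediction and ground truth next to each other (along the width)
# so both can be shown with a single imshow call.
side_by_side = torch.cat([pred[0], target[0]], dim=1)  # [24, 48]

# Boolean mask of misclassified pixels, useful as a separate panel.
errors = pred[0] != target[0]  # [24, 24]
pixel_acc = (~errors).float().mean().item()
print(f"pixel accuracy: {pixel_acc:.3f}")
```

Passing `side_by_side.numpy()` to `plt.imshow` with `vmin=0, vmax=9` keeps the same color scale for both halves, and `errors.numpy()` can be shown in its own subplot.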
Can you tell me how you plotted this image in 3D like that?
You can use MATLAB to plot it. It's a special type of image called “hyperspectral”, and these colored images are only the labeled data.
You could add it directly after output = model(data).
The better approach would probably be to store all predictions in a list and plot some images after the test loop finishes:
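import torch
import matplotlib.pyplot as plt
model.eval()  # model and test_loader are assumed to come from your training code
preds = []
with torch.no_grad():  # no gradients are needed during testing
    for data, target in test_loader:
        ...
        output = model(data)  # [batch_size, nb_classes, height, width]
        pred = torch.argmax(output, 1)  # [batch_size, height, width]
        preds.append(pred.cpu())  # move to the CPU so .numpy() works later
        ...
preds = torch.cat(preds)  # [nb_samples, height, width]
nb_imgs_to_plot = 4
fig, axarr = plt.subplots(nb_imgs_to_plot)
for idx, img in enumerate(preds[:nb_imgs_to_plot]):
    axarr[idx].imshow(img.numpy())
plt.show()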
Hello, in this tutorial the images appear blurry. Is there a way to fix that?
CIFAR10 images have a spatial resolution of 32x32 pixels, so they unfortunately look somewhat blurry.
You could try to interpolate them to a higher resolution, but I’m not sure if that’s really useful.
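If you want to try it, a minimal sketch using `torch.nn.functional.interpolate` on a dummy CIFAR10-sized batch is shown below. Upsampling cannot recover detail the 32x32 images never had; it only changes how the existing pixels are displayed.

```python
import torch
import torch.nn.functional as F

# Dummy batch shaped like CIFAR10 images: [batch_size, channels, 32, 32], values in [0, 1].
images = torch.rand(4, 3, 32, 32)

# Upsample 4x. mode="nearest" repeats each pixel (blocky but not smoothed),
# while mode="bicubic" gives a smoother result; neither adds real detail.
up_nearest = F.interpolate(images, scale_factor=4, mode="nearest")
up_bicubic = F.interpolate(images, scale_factor=4, mode="bicubic", align_corners=False)
# Bicubic interpolation can overshoot slightly, so clamp back to [0, 1] before plotting.
up_bicubic = up_bicubic.clamp(0, 1)
```

If the blur only comes from how matplotlib scales the small image for display, passing `interpolation="nearest"` to `plt.imshow` shows the raw pixels without resizing the tensor.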