Show the classified images after the network has been trained?

Hello everyone.
Is it possible to generate the results of the trained network, i.e. show the classified images after training, in PyTorch?
Any help is much appreciated.
Many thanks

Would you like to classify and visualize unseen test images using your trained model?
If so, the Training a Classifier tutorial might give you some ideas on how to write the code. In that tutorial, test images are also passed through the trained network and visualized, and their predicted classes are printed.
Let me know if I misunderstood your use case.
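Loosely following the approach in that tutorial, a minimal sketch of classifying a batch of images and printing the predicted class names could look like this. The tiny model, the random input batch, and the `classes` tuple are placeholders (assumptions) standing in for your trained network and your test data.

```python
import torch
import torch.nn as nn

# Placeholder class names (CIFAR10-style); replace them with your own labels.
classes = ("plane", "car", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")

# Hypothetical stand-in for a trained classifier: flattens 3x32x32 images
# and maps them to one logit per class.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, len(classes)))
model.eval()

images = torch.randn(4, 3, 32, 32)  # dummy batch instead of real test images

with torch.no_grad():
    outputs = model(images)               # [batch_size, nb_classes]
    predicted = torch.argmax(outputs, 1)  # one class index per image

predicted_names = [classes[i] for i in predicted.tolist()]
print("Predicted:", " ".join(predicted_names))
```

With a real model you would swap in your trained network and a batch from your test loader; the `argmax` over the class dimension stays the same.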


Thanks a lot, ptrblck. Yes, what I mean is that after training the network I would like to visualize the result of the classification (validation). The images look like this.
It's a kind of pixel-wise classification, so I want to reproduce these results from a paper:

It looks like you are dealing with a segmentation task.
So I assume your output would have the shape [batch_size, nb_classes, height, width], is that correct?
If so, you can get the predicted class index using pred = torch.argmax(output, 1).
With the pred tensor, you should then be able to create these color images by mapping each class index to a color code.
If you don't have such a mapping yet, I would suggest creating a dict that assigns a color of your choice to each class index.
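A minimal sketch of such a mapping, assuming four classes with arbitrarily chosen colors. The `color_map` dict and the `colorize` helper are hypothetical names for illustration, not part of PyTorch.

```python
import torch

# Hypothetical mapping from class index to an RGB color (values 0-255).
# Any colors work; the only requirement is one entry per class index.
color_map = {
    0: (0, 0, 0),      # e.g. background / unlabeled
    1: (255, 0, 0),
    2: (0, 255, 0),
    3: (0, 0, 255),
}


def colorize(pred, color_map):
    """Turn a [height, width] tensor of class indices into a [height, width, 3] RGB image."""
    nb_classes = max(color_map) + 1
    palette = torch.zeros(nb_classes, 3, dtype=torch.uint8)
    for class_idx, rgb in color_map.items():
        palette[class_idx] = torch.tensor(rgb, dtype=torch.uint8)
    # Indexing the palette with the class-index tensor looks up every pixel's color at once.
    return palette[pred]


output = torch.randn(1, len(color_map), 24, 24)  # [batch_size, nb_classes, height, width]
pred = torch.argmax(output, 1)                   # [batch_size, height, width]
rgb = colorize(pred[0], color_map)               # [24, 24, 3], dtype uint8
```

The resulting `rgb` tensor can be displayed with `plt.imshow(rgb.numpy())`. Building a small palette tensor once avoids looping over pixels in Python.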


Indeed, it is a segmentation task, and I do have the ground truth. Is it possible to visualize the output then?

Yes, as I said, you can get the predictions using torch.argmax and visualize them directly or via a color mapping scheme:

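```python
import torch

output = torch.randn(1, 10, 24, 24)  # [batch_size, nb_classes, height, width]
pred = torch.argmax(output, 1)       # [batch_size, height, width], one class index per pixel
```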
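To visualize `pred` directly, one option is to plot it with matplotlib next to the ground truth. This is a sketch under a few assumptions: `target` is random dummy data standing in for your ground-truth mask, `plot_prediction_vs_target` is a hypothetical helper name, and the `"tab10"` colormap only provides 10 distinct colors.

```python
import matplotlib.pyplot as plt
import torch


def plot_prediction_vs_target(pred, target, nb_classes):
    """Show a predicted segmentation map next to its ground truth.

    pred, target: [height, width] tensors of class indices.
    """
    fig, (ax_pred, ax_target) = plt.subplots(1, 2, figsize=(8, 4))
    # Same colormap and value range in both panels, so a given class index
    # is drawn in the same color in the prediction and the ground truth.
    ax_pred.imshow(pred.cpu().numpy(), cmap="tab10", vmin=0, vmax=nb_classes - 1,
                   interpolation="nearest")
    ax_pred.set_title("Prediction")
    ax_target.imshow(target.cpu().numpy(), cmap="tab10", vmin=0, vmax=nb_classes - 1,
                     interpolation="nearest")
    ax_target.set_title("Ground truth")
    for ax in (ax_pred, ax_target):
        ax.axis("off")
    return fig


nb_classes = 10
output = torch.randn(1, nb_classes, 24, 24)         # [batch_size, nb_classes, height, width]
pred = torch.argmax(output, 1)                      # [batch_size, height, width]
target = torch.randint(0, nb_classes, (1, 24, 24))  # dummy ground truth
fig = plot_prediction_vs_target(pred[0], target[0], nb_classes)
```

Call `plt.show()` afterwards to open the figure. With more than 10 classes, a colormap such as `"tab20"` or a custom class-to-color mapping would be needed to keep the classes distinguishable.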

Can you tell me how you plotted this image in 3D like that? :smiley:

Where exactly should I add it, i.e. in which part of the code? :shushing_face: I got confused, ptrblck.


You can use MATLAB to plot it. It's a special type of image called “hyperspectral”, and these colored images are only the labeled data.

You could add it directly after output = model(data).
The better approach would probably be to store all predictions in a list and plot some images after the test loop finishes:

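```python
import torch
import matplotlib.pyplot as plt

model.eval()  # disable dropout, use running batchnorm statistics
preds = []
with torch.no_grad():  # no gradients are needed for inference
    for data, target in test_loader:
        output = model(data)
        pred = torch.argmax(output, 1)  # [batch_size, height, width]
        preds.extend(pred.cpu())  # store one [height, width] tensor per image

nb_imgs_to_plot = 4
fig, axarr = plt.subplots(nb_imgs_to_plot)
for idx, img in enumerate(preds[:nb_imgs_to_plot]):
    axarr[idx].imshow(img.numpy())
plt.show()
```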

Hello, in this tutorial the images appear blurry. Is there a way to fix that?

CIFAR10 images have a spatial resolution of only 32x32 pixels, so unfortunately they look blurry when enlarged. :confused:
You could try to interpolate them to a higher resolution, but I'm not sure that's really useful, since interpolation cannot add detail that isn't in the original images.
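If you want to try the interpolation anyway, a minimal sketch with `torch.nn.functional.interpolate` could look like this. The random tensor stands in for a batch of CIFAR10 images scaled to [0, 1] (an assumption), and the 4x factor and bicubic mode are arbitrary choices.

```python
import torch
import torch.nn.functional as F

# Dummy batch standing in for CIFAR10 images: [batch_size, channels, 32, 32], values in [0, 1]
images = torch.rand(4, 3, 32, 32)

# Upsample by 4x to 128x128. Bicubic interpolation smooths the result;
# it cannot recover detail that was never captured in the 32x32 input.
upsampled = F.interpolate(images, scale_factor=4, mode="bicubic", align_corners=False)

# Bicubic interpolation can overshoot slightly, so clamp back into [0, 1].
upsampled = upsampled.clamp(0, 1)
print(upsampled.shape)  # torch.Size([4, 3, 128, 128])
```

If the goal is only to display the images, passing `interpolation="nearest"` to `plt.imshow` shows each pixel as a sharp block instead of smoothing it.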
