How to view prediction images during testing?

At the end of the tutorial at this link, there are some images and their predictions. How did you get these?

Also, as far as I can see, you don’t define any model class.
I can’t load my saved models with

model = torch.load(PATH)

because it requires a model class, but I followed this tutorial and I don’t see any class definition.
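A common pattern that avoids depending on a pickled model object is to save only the model’s `state_dict` and, at test time, rebuild the architecture with the same helper used for training before loading the weights. Below is a minimal sketch under stated assumptions: `build_model()` is a hypothetical stand-in for `get_model_instance_segmentation(num_classes)`, and an in-memory buffer replaces a real file path so the example runs anywhere.

```python
import io

import torch
import torch.nn as nn


def build_model(num_classes: int) -> nn.Module:
    # Hypothetical stand-in for get_model_instance_segmentation(num_classes):
    # whatever function built the model for training must rebuild it here.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, num_classes))


# --- after training: save only the learned parameters ---
model = build_model(num_classes=2)
buffer = io.BytesIO()  # stands in for a file path such as 'model.pt'
torch.save(model.state_dict(), buffer)

# --- later, e.g. in a separate test script: rebuild the architecture, then load ---
buffer.seek(0)
restored = build_model(num_classes=2)
restored.load_state_dict(torch.load(buffer, map_location="cpu"))
restored.eval()  # switch off training-only behaviour before inference
```

With a real file, the same calls work with a path string in place of `buffer`; the only requirement is that the rebuilt model has the same architecture and `num_classes` as the one that was saved.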

@ptrblck_de I see you answer a lot of questions, and I apologize for mentioning you, but I really need help. Thanks.

The tutorial mentions the full source code at the end, which you can download from here and which should also include the code to visualize the predictions.

Yes, I downloaded it and am working from that template; I just plugged in my own dataset and modified the train_one_epoch() function.

There is no code for visualizing the predictions.

If you cannot find the corresponding code snippet, check the linked Colab notebook, which shows the visualization.

I tried following the Google Colab notebook as well, and still nothing happens. I don’t know what I’m doing wrong.

Here is my main function (the get_model_instance_segmentation and get_transform functions are defined above main in the script, and the Moj_Dataset_ArT class is defined at the beginning of the script):

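import torch
from PIL import Image

# helper modules downloaded with the tutorial (references/detection)
import utils
from engine import train_one_epoch, evaluate

def main():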
    # train on the GPU or on the CPU, if a GPU is not available
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # our dataset has two classes only - background and person
    num_classes = 2
    # use our dataset and defined transformations
    dataset = Moj_Dataset_ArT('Train/ArT', get_transform(train=True))
    dataset_val = Moj_Dataset_ArT('Train/ArT', get_transform(train=False))
    dataset_test = Moj_Dataset_ArT('Train/ArT', get_transform(train=False))

    # split the dataset in train and test set
    indices = torch.randperm(len(dataset)).tolist()
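    dataset = torch.utils.data.Subset(dataset, indices[:-1159])
    #print(len(dataset)) -> 2400
    dataset_val = torch.utils.data.Subset(dataset_val, indices[2400:-559])
    #print(len(dataset_val)) -> 600
    dataset_test = torch.utils.data.Subset(dataset_test, indices[-559:])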
    #print(len(dataset_test)) -> 559

    # define training and validation data loaders
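    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=True, num_workers=4,
        collate_fn=utils.collate_fn)

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=4, shuffle=False, num_workers=4,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=4,
        collate_fn=utils.collate_fn)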

    # get the model using our helper function
    model = get_model_instance_segmentation(num_classes)

    # move model to the right device
    model.to(device)

    # construct an optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.001,
                                momentum=0.9, weight_decay=0.0005)
    # and a learning rate scheduler
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,  # value from the tutorial
                                                   gamma=0.1)  # gamma used to be 0.5

    # train for num_epochs epochs
    num_epochs = 2
    #PATH = '/home/Nezz/Train/ArT/TorchScript_format/'
    #model_scripted = torch.jit.script(model) # Export to TorchScript
    ml = []
    vl = []
    for epoch in range(num_epochs):
        # train for one epoch, printing every 10 iterations
        #loss_value =  train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
        # update the learning rate
        lr_scheduler.step()
        # evaluate on the test dataset
        evaluate(model, data_loader_test, device=device)
        #val_loss  = evaluate_loss(model, data_loader_val, device=device) 
        #vl.append(val_loss), os.path.join(PATH, 'epoch-{}.pt'.format(epoch)))
    #print("train_loss", ml) 
    #print("val_loss", vl)

    # pick one image from the test set
    img, _ = dataset_test[20]
    # put the model in evaluation mode
    model.eval()
    with torch.no_grad():
        prediction = model([img.to(device)])
    Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())
    Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())

    #print("That's it!")
if __name__ == "__main__":
    main()

Image.fromarray creates a PIL.Image, which a notebook displays automatically when it is the last expression in a cell; in a plain script, nothing is shown. Call .show() on the images and it should work.
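If the script runs on a machine without a display, `.show()` may not open anything, so saving the images to disk is an alternative. Below is a minimal sketch, assuming the torchvision Mask R-CNN output format (a list with one dict per input image, where `'masks'` has shape `[N, 1, H, W]` with values in `[0, 1]`). The helper name `prediction_to_images` and the `mask_threshold=0.5` binarization are illustrative choices, not part of the tutorial, and the dummy tensors only stand in for a real test image and model output.

```python
import torch
from PIL import Image


def prediction_to_images(img, prediction, mask_threshold=0.5):
    """Hypothetical helper: turn an input image tensor and the first predicted
    mask into PIL images (input, soft mask, thresholded binary mask)."""
    # [C, H, W] float in [0, 1] -> [H, W, C] uint8, as in the snippet above
    image = Image.fromarray(img.mul(255).permute(1, 2, 0).byte().cpu().numpy())
    # first predicted instance, single channel, scaled to 0..255
    soft = prediction[0]["masks"][0, 0]
    soft_mask = Image.fromarray(soft.mul(255).byte().cpu().numpy())
    # optional binary version: pixels above the threshold become white (255)
    binary = (soft > mask_threshold).to(torch.uint8) * 255
    binary_mask = Image.fromarray(binary.cpu().numpy())
    return image, soft_mask, binary_mask


# Dummy tensors standing in for a real test image and model output.
img = torch.rand(3, 64, 48)
prediction = [{"masks": torch.rand(2, 1, 64, 48)}]

image, soft_mask, binary_mask = prediction_to_images(img, prediction)
image.save("input.png")      # works on headless machines
soft_mask.save("mask.png")
# image.show(); soft_mask.show()  # opens a viewer when a display is available
```

In a notebook, leaving `image` or `soft_mask` as the last expression of a cell displays it inline without either call.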


Yeah, it works. Thanks a lot, man! :slight_smile: