Confusion matrix for semantic segmentation

(Neda) #1

After train the model, I am using this snippet to report the confusion matrix, score accuracy,… I am not sure am I doing correctly or the confusion matrix calculation should be inside the training loop.

from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score 
from sklearn.metrics import classification_report 
from sklearn.metrics import roc_curve, auc # roc curve tools  

# confusion matrix          
    for phase in ['valid', 'test']:  
        with torch.no_grad():
            for t_image, mask, image_paths, target_paths in dataLoaders[phase]:
            # get the inputs
                t_image, mask =,
                output = model(t_image)
                probability_class = torch.exp(output) # logit or probability for two classes
                prediction = torch.argmax(probability_class,1) # predic for each class (torch.argmax Returns the indices of the maximum values of a tensor across a dimension)
                cm = confusion_matrix(prediction.view(-1), mask.view(-1))
                #print('{} {} cm:'.format(phase, cm))
                # performance metrics for the model
                score_accuracy = accuracy_score(prediction.view(-1), mask.view(-1))
                report = classification_report(prediction.view(-1), mask.view(-1)) 
                fpr, tpr, _ = roc_curve(prediction.view(-1), mask.view(-1))
                roc_auc = auc(fpr,tpr)

            print('{} score_accuracy: {:.4f}'.format(phase, score_accuracy))  
            print('{} Confusion Matrix: {}'.format(phase, cm))  
            plt.imshow(cm, interpolation='None',
            classNames = ['Negative','Positive']
            plt.title('Confusion Matrix')
            plt.ylabel('True class')
            plt.xlabel('Predicted class')
            tick_marks = numpy.arange(len(classNames))
            plt.xticks(tick_marks, classNames)
            plt.yticks(tick_marks, classNames)
            s = [['TN','FP'], ['FN', 'TP']]
            for i in range(2):
                for j in range(2):
                    plt.text(j,i, str(s[i][j])+" = "+str(cm[i][j]))                  
(lavenderxx) #2

Did you find the answer for this ? i am having the same problem