Implementation of IOU for multiclass semantic segmentation

Hello! I would like to compute the training accuracy of my neural network, but using IoU as the metric. The implementations I found online are mostly for binary classification only. Can anyone help me?

By the way, this is my script so far:

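import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

# ESNet, decode_segmap, img and mask are defined elsewhere in my notebook
def display_segmentation():

  # Load the trained 6-class model from the checkpoint
  model = ESNet(classes=6)
  path = "/content/drive/MyDrive/Thesis_QuilangCastillano/Models/ESNET_5250.37-checkpoint.pt"
  checkpoint = torch.load(path)
  model.load_state_dict(checkpoint['state_dict'])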

  total = 0
  correct = 0

  torch.set_printoptions(edgeitems=128)

  with torch.no_grad():
    model.eval()
    for i in tqdm(range(1)):
      net_out = model(img[100].view(-1, 3, 384, 512))
      segmentation_label = torch.argmax(net_out.squeeze(), dim= 0)
      segmented_image = decode_segmap(segmentation_label)

     
      # Class IDs present in the prediction and in the ground-truth mask
      label = np.unique(segmentation_label)
      groundtruth = np.unique(mask[100])

      alpha = 0.3 # how much transparency to apply
      beta = 1 - alpha # alpha + beta should equal 1
      gamma = 0 # scalar added to each sum

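      # Assuming img[100] is channel-first (3, 384, 512), as the model input suggests:
      # permute to (384, 512, 3) for display, since view() would scramble the pixels
      image = np.ascontiguousarray(img[100].permute(1, 2, 0))
      prediction = cv2.addWeighted(segmented_image, alpha, image, beta, gamma, image, dtype=cv2.CV_64F)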

      Sub = f"Prediction: {label}, Groundtruth: {groundtruth}"

      plt.imshow(prediction)
      plt.title(Sub)
      plt.show()
   

  
display_segmentation()

This script displays the prediction overlaid on the input image, with both arrays shown in the plot title.

The groundtruth array contains the class numbers that should be present in the image; the prediction array contains the class numbers that the neural net thinks the image is composed of.

I would recommend that you don't start from vanilla PyTorch, as there are multiple decent libraries built on top of it that simplify building models.
You could try Lightning Bolts for a baseline model architecture and training procedure (you just need to replace the provided dataset with your own). For the IoU implementation, you can use torchmetrics.
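
If you would rather see what the metric does under the hood, here is a minimal sketch of multiclass IoU built from a confusion matrix in plain NumPy. It is not taken from any library: the helper names `confusion_matrix` and `per_class_iou` and the toy arrays are made up for illustration. It assumes that predictions and masks are integer class maps of the same shape.

```python
import numpy as np


def confusion_matrix(pred, target, num_classes):
    """Count pixels for every (ground-truth class, predicted class) pair."""
    pred = np.asarray(pred).reshape(-1).astype(np.int64)
    target = np.asarray(target).reshape(-1).astype(np.int64)
    # Encode each (target, pred) pair as one index, then histogram the indices.
    index = target * num_classes + pred
    counts = np.bincount(index, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def per_class_iou(conf):
    """IoU per class = TP / (TP + FP + FN); NaN if a class never appears."""
    tp = np.diag(conf).astype(np.float64)
    fp = conf.sum(axis=0) - tp  # predicted as class c, labelled as another class
    fn = conf.sum(axis=1) - tp  # labelled as class c, predicted as another class
    union = tp + fp + fn
    iou = np.full(conf.shape[0], np.nan)
    np.divide(tp, union, out=iou, where=union > 0)
    return iou


# Toy 2x3 "images" with 3 classes (hypothetical values, not from the thread).
pred = np.array([[0, 0, 1], [1, 2, 2]])
target = np.array([[0, 1, 1], [1, 2, 0]])

conf = confusion_matrix(pred, target, num_classes=3)
iou = per_class_iou(conf)
mean_iou = np.nanmean(iou)  # ignores classes absent from both arrays
print(iou, mean_iou)
```

In the script above, this would presumably be called with `segmentation_label.cpu().numpy()` and `mask[100]` and `num_classes=6`. To get a dataset-level score, summing the confusion matrices over all images before calling `per_class_iou` is usually more stable than averaging per-image IoU values.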