Computing mIoU during validation

I am working on a binary segmentation task and have implemented the following training and validation loop. I need help with two points:

  1. How can I compute the IoU for each class after every epoch and print the Class 1 IoU, Class 2 IoU, and the overall mIoU score?
  2. Is it better to save the model based on the best mIoU score or the lowest validation loss?

Any guidance would be greatly appreciated. Thanks!

Here is my code:

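import torch
from tqdm import tqdm

# model, criterion, optimizer, train_loader, val_loader, device and n_eps are defined elsewhere

# Initialize lists to store loss values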
train_losses = []
val_losses = []

# Training and validation loop
for epoch in range(n_eps):
    model.train()
    train_loss = 0.0

    # Training loop
    for images, masks in tqdm(train_loader):
        images, masks = images.to(device), masks.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Epoch [{epoch+1}/{n_eps}], Train Loss: {avg_train_loss:.4f}")

    model.eval()
    val_loss = 0.0

    # Validation loop
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            val_loss += criterion(outputs, masks).item()

    avg_val_loss = val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    print(f"Epoch [{epoch+1}/{n_eps}], Val Loss: {avg_val_loss:.4f}")
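One way to get per-class scores without any extra library is to accumulate a 2x2 confusion matrix over the whole validation set and compute IoU = TP / (TP + FP + FN) for each class at the end of the epoch. Accumulating over the full set, rather than averaging per-batch IoUs, avoids skew from batches where one class barely appears. This is a minimal sketch under assumptions not stated in the original code: the model returns either `(N, 2, H, W)` logits (CrossEntropyLoss-style) or `(N, 1, H, W)` logits (BCEWithLogitsLoss-style), and the masks contain only the labels 0 and 1. The helper names `logits_to_labels`, `update_confusion` and `iou_from_confusion` are hypothetical.

```python
import torch


def logits_to_labels(outputs: torch.Tensor) -> torch.Tensor:
    """Turn raw model outputs into integer class labels of shape (N, H, W)."""
    if outputs.shape[1] == 1:
        # Assumed single-logit output (BCEWithLogitsLoss-style): positive logit -> class 1
        return (outputs > 0).long().squeeze(1)
    # Assumed multi-channel output (CrossEntropyLoss-style): pick the highest-scoring class
    return outputs.argmax(dim=1)


def update_confusion(confmat: torch.Tensor, preds: torch.Tensor, target: torch.Tensor) -> None:
    """Add one batch to a (C, C) confusion matrix in place (rows: true class, cols: predicted)."""
    num_classes = confmat.shape[0]
    preds = preds.reshape(-1).long()
    target = target.reshape(-1).long()  # masks may be float 0./1. when using BCE losses
    # Encode each (true, pred) pair as a single index, then count how often each pair occurs
    idx = target * num_classes + preds
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    confmat += counts.reshape(num_classes, num_classes).to(confmat)


def iou_from_confusion(confmat: torch.Tensor) -> torch.Tensor:
    """Per-class IoU = TP / (TP + FP + FN); NaN for a class absent from preds and targets."""
    confmat = confmat.double()
    tp = confmat.diag()
    fp = confmat.sum(dim=0) - tp  # predicted as class c, but truly another class
    fn = confmat.sum(dim=1) - tp  # truly class c, but predicted as another class
    return tp / (tp + fp + fn)  # 0 / 0 gives NaN in floating point
```

To wire this into the validation loop above, create `confmat = torch.zeros(2, 2, dtype=torch.long, device=device)` before `with torch.no_grad():`, call `update_confusion(confmat, logits_to_labels(outputs), masks)` right after `outputs = model(images)`, and after the loop compute `iou = iou_from_confusion(confmat)`. Then `iou[0]` is the Class 1 IoU, `iou[1]` the Class 2 IoU, and `torch.nanmean(iou)` the mIoU; `nanmean` skips a class that never appears in either predictions or targets instead of letting its NaN propagate into the mean.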

I like to use the mIoU metric from torchmetrics. You just need to pass it your model's predictions and the targets; I think it's a good choice for your problem.

As for your second question, it depends on what you are looking for. If you are just training the model and then using it for testing, I suggest keeping the checkpoint with the highest mIoU. If you are training it to use later as a pretrained model, selecting by the lowest validation loss may be better.
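Whichever criterion you pick, the bookkeeping is the same: remember the best value seen so far and save whenever it improves. Below is a minimal sketch, assuming a hypothetical `BestCheckpoint` helper that takes a save callback; `mode="max"` fits mIoU and `mode="min"` fits validation loss.

```python
import math
from typing import Callable


class BestCheckpoint:
    """Hypothetical helper: calls save_fn whenever the monitored metric improves."""

    def __init__(self, save_fn: Callable[[], None], mode: str = "max") -> None:
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' (e.g. mIoU) or 'min' (e.g. val loss)")
        self.save_fn = save_fn
        self.mode = mode
        # Start from the worst possible value so the first real value always counts as an improvement
        self.best = -math.inf if mode == "max" else math.inf

    def step(self, value: float) -> bool:
        """Save and return True if `value` beats the best seen so far; otherwise return False."""
        if math.isnan(value):
            return False  # e.g. an undefined mIoU should never overwrite a good checkpoint
        improved = value > self.best if self.mode == "max" else value < self.best
        if improved:
            self.best = value
            self.save_fn()
        return improved
```

For example, create `best = BestCheckpoint(lambda: torch.save(model.state_dict(), "best_miou.pt"), mode="max")` before the epoch loop and call `best.step(miou)` at the end of each epoch. You could also keep two trackers, one per criterion writing to different files, and decide afterwards which checkpoint to use; the file names here are placeholders.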

@Eduardo_Lawson thank you for your response. I did try torchmetrics, but the problem is that it does not give individual class scores for binary segmentation.