Testing the Binary Segmentation Model and Calculating IoU

Hello, I am new to machine learning and am training a binary segmentation model in PyTorch. I have loaded the trained model and have around 120 test images. Could anyone please check whether my testing loop and IoU calculation are correct?

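import torch
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 8
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Load the saved model and move it to the same device as the inputs
model = EncoderDecoder(in_channels=1, out_channels=1)
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.to(device)
model.eval()  # Set the model to evaluation mode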

# Define lists to store predicted masks and ground truth labels
predicted_masks = []
ground_truth_labels = []
original_images = []

# Iterate over the test dataset
for inputs, labels in test_loader:
    # Move data to the same device as the model
    inputs, labels = inputs.to(device), labels.to(device)
    original_images.append(inputs)

    # Forward pass
    with torch.no_grad():
        outputs = model(inputs)

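    # Convert outputs to binary masks using a threshold.
    # Note: 0.5 assumes the model's last layer is a sigmoid; if it returns raw logits,
    # apply torch.sigmoid(outputs) first (or threshold the logits at 0).
    predicted_masks.append((outputs > 0.5).float())  # Keep the binary mask as a float tensor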
    
    # Append labels directly as tensors
    ground_truth_labels.append(labels)

# Concatenate the ground truth labels
ground_truth_labels = torch.cat(ground_truth_labels, dim=0)

# Concatenate original images
original_images = torch.cat(original_images, dim=0)

# Concatenate the predicted masks into a single tensor
predicted_masks = torch.cat(predicted_masks, dim=0)

print(predicted_masks.shape)  # torch.Size([120, 1, 650, 1250])

# Calculating IoU score
def iou_pytorch(inputs: torch.Tensor, targets: torch.Tensor, smooth: float = 1e-6):
    # Flatten both tensors; reshape also works on non-contiguous tensors
    inputs = inputs.reshape(-1).float()
    targets = targets.reshape(-1).float()

    intersection = (inputs * targets).sum().float()
    total = (inputs + targets).sum().float()
    union = total - intersection 

    iou = (intersection + smooth) / (union + smooth)
    
    return iou.item()
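
One way to check a function like this is to run it on masks small enough to work out by hand. The sketch below is only an illustration: the tiny tensors are made up, and `iou_per_image` is a hypothetical helper that is not part of the original code. It also shows that flattening the whole test set gives one global IoU, which can differ from the mean of the per-image IoUs.

```python
import torch

def iou_pytorch(inputs: torch.Tensor, targets: torch.Tensor, smooth: float = 1e-6) -> float:
    # Global IoU: every pixel of every image is pooled into one score
    inputs = inputs.reshape(-1).float()
    targets = targets.reshape(-1).float()
    intersection = (inputs * targets).sum()
    union = inputs.sum() + targets.sum() - intersection
    return ((intersection + smooth) / (union + smooth)).item()

def iou_per_image(preds: torch.Tensor, targets: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor:
    # Hypothetical helper: one IoU value per image, shape (N,)
    preds = preds.flatten(start_dim=1).float()
    targets = targets.flatten(start_dim=1).float()
    intersection = (preds * targets).sum(dim=1)
    union = preds.sum(dim=1) + targets.sum(dim=1) - intersection
    return (intersection + smooth) / (union + smooth)

# Two made-up 2x3 masks with shape (N=2, C=1, H=2, W=3)
pred = torch.tensor([[[[1., 1., 0.], [0., 1., 0.]]],
                     [[[0., 0., 0.], [0., 0., 1.]]]])
target = torch.tensor([[[[1., 0., 0.], [0., 1., 1.]]],
                       [[[0., 0., 0.], [0., 0., 1.]]]])

# Image 1: intersection 2, union 4 -> 0.5. Image 2: intersection 1, union 1 -> 1.0
per_image = iou_per_image(pred, target)
# Pooled: intersection 3, union 5 -> approximately 0.6
global_iou = iou_pytorch(pred, target)

print(per_image, per_image.mean().item(), global_iou)
```

Here the global score is about 0.6 while the per-image mean is about 0.75, so it is worth deciding which of the two you want to report.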

@ptrblck waiting…!!!

Please don’t tag specific users in questions, as it will discourage others from posting valid answers.
You can compare your implementation with other libraries (e.g. I guess torchmetrics or kornia would implement IoU already).