Bad IoU and irregular prediction result

I am working on a multiclass semantic segmentation project with 4 classes (background not included) and annotations in COCO JSON format. The original dataset has 10 train images and 4 test images, but I also have a consistent data augmentation script that gives me 100 train images and 20 test images. I use PyTorch as the framework and a pretrained U-Net from the SMP library. When I use the Dataset and DataLoader classes to generate the batches for the model, the batches have the following shapes:

batch_images, masks = batch
batch_images.shape: torch.Size([1, 3, 1024, 1024])
type batch_images: <class 'torch.Tensor'>
batch_images.dtype: torch.float32
batch_masks.shape: torch.Size([1, 4, 1024, 1024])
type batch_mask: <class 'torch.Tensor'>
batch_mask.dtype: torch.float32

I also have a method inside the Dataset object that computes the class frequencies (in pixels) and the class weights (using the inverse-frequency method). For the augmented dataset it shows:

Class Frequencies_Train: [16707375. 3310340. 3782234. 3841229.]
Class Weights_Train: tensor([0.1981, 1.0000, 0.8752, 0.8618])
Class Frequencies_Test: [1167281. 1984024. 1058666. 1753069.]
Class Weights_Test: tensor([0.9070, 0.5336, 1.0000, 0.6039])
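
The weights above are consistent with taking the inverse of each class frequency and normalizing so that the rarest class gets weight 1. The actual Dataset method is not shown here, so the following is only a minimal sketch of that calculation; the function name `inverse_frequency_weights` is illustrative, and it assumes every class has at least one pixel.

```python
import torch

def inverse_frequency_weights(pixel_counts):
    """Inverse-frequency class weights, scaled so the rarest class has weight 1.

    Assumes every count is > 0 (a zero count would divide by zero).
    """
    counts = torch.as_tensor(pixel_counts, dtype=torch.float64)
    inverse = 1.0 / counts
    return (inverse / inverse.max()).float()

# Pixel counts reported for the augmented train and test sets
train_counts = [16707375, 3310340, 3782234, 3841229]
test_counts = [1167281, 1984024, 1058666, 1753069]

train_weights = inverse_frequency_weights(train_counts)
test_weights = inverse_frequency_weights(test_counts)
print(train_weights)
print(test_weights)
```

Rounded to four decimals, both results agree with the weights printed above.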

Then I use the following model and training script:

import os

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

num_classes = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class_weights_train = torch.tensor(class_weights_train).to(device)
class_weights_test = torch.tensor(class_weights_test).to(device)

encoder = "resnet101"
e_weights = "imagenet"    

model_Unet = smp.Unet(
    encoder_name=f"{encoder}",
    encoder_weights=f"{e_weights}",
    in_channels=3,
    classes=num_classes,
)

criterion = nn.CrossEntropyLoss()
learning_rate = 0.0001
weight_decay = 0.0005
optimizer = torch.optim.Adam(model_Unet.parameters(), lr=learning_rate, weight_decay=weight_decay)
model_Unet.to(device)
model_Unet.train()

#Training loop
session = 135
session_folder = f"training_session_{session}"
os.makedirs(session_folder, exist_ok=True)
num_epochs = 300
checkpoint_interval = 100

train_losses = []
train_iou = []
test_losses = []
test_iou = []

for epoch in range(num_epochs):
    model_Unet.train()
    train_loss = 0.0
    intersection = 0
    union = 0
    for i, (images, masks) in enumerate(dataloader_train):
        optimizer.zero_grad()
        images = images.to(device)
        masks = masks.squeeze(1).float().to(device)
        outputs = model_Unet(images)
        loss = criterion(outputs, masks)
        class_weights = torch.tensor(class_weights_train, dtype=torch.float32, device=device)
        weighted_loss = loss * class_weights
        train_loss += torch.sum(weighted_loss).item()
        loss.backward()
        optimizer.step()

        
        predicted_masks = torch.argmax(outputs, dim=1).float()
        intersection += torch.sum(predicted_masks * masks).item()
        union += torch.sum((predicted_masks + masks) > 0).item()

        # Loss
        train_loss += loss.item()
        print(f'Epoch: {epoch+1}/{num_epochs}\t Iteration: {i+1}/{len(dataloader_train)}')

    # Average Loss
    train_loss /= len(dataloader_train.dataset)

    # IoU
    iou_train = intersection / union

    train_losses.append(train_loss)
    train_iou.append(iou_train)
    print(f'Epoch: {epoch+1}/{num_epochs}\t Training Loss: {train_loss}\t IoU: {iou_train}')

    model_Unet.eval()
    test_loss = 0.0
    intersection_test = 0
    union_test = 0
    with torch.no_grad():
        for images, masks in dataloader_test:
            images = images.to(device)
            masks = masks.squeeze(1).float().to(device)
            outputs = model_Unet(images)
            loss = criterion(outputs, masks)
            class_weights = torch.tensor(class_weights_test, dtype=torch.float32, device=device)
            weighted_loss = loss * class_weights
            test_loss += torch.sum(weighted_loss).item()
       
            predicted_masks = torch.argmax(outputs, dim=1).float()
            intersection_test += torch.sum(predicted_masks * masks).item()
            union_test += torch.sum((predicted_masks + masks) > 0).item()

    # Average Test Loss
    test_loss /= len(dataloader_test.dataset)

    # IoU
    iou_test = intersection_test / union_test

    test_losses.append(test_loss)
    test_iou.append(iou_test)
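
Note that in the loop above the class weights only scale the value that gets logged: `loss.backward()` is called on the unweighted loss, so the weights never influence the gradients. One possible way to let them affect training is the `weight` argument of `nn.CrossEntropyLoss` together with class-index targets. The following is a minimal sketch, assuming the masks are one-hot over the 4 classes and that pixels with no class set (background) should simply be ignored by the loss. The helper `one_hot_to_index` and the value 255 are hypothetical choices, and the random tensor stands in for the model output.

```python
import torch
import torch.nn as nn

IGNORE_INDEX = 255  # hypothetical label for pixels that belong to no class

def one_hot_to_index(masks, ignore_index=IGNORE_INDEX):
    """Convert [N, C, H, W] one-hot masks to [N, H, W] class indices.

    Pixels where every channel is zero are marked with ignore_index.
    """
    target = masks.argmax(dim=1)
    unlabeled = masks.sum(dim=1) == 0
    target[unlabeled] = ignore_index
    return target.long()

class_weights = torch.tensor([0.1981, 1.0000, 0.8752, 0.8618])
criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=IGNORE_INDEX)

# Small stand-ins for one batch: model logits and one-hot masks
logits = torch.randn(1, 4, 8, 8, requires_grad=True)
masks = torch.zeros(1, 4, 8, 8)
masks[:, 0, :4, :] = 1.0    # top half is class 0
masks[:, 2, 4:, :4] = 1.0   # bottom-left quarter is class 2, rest is unlabeled

target = one_hot_to_index(masks)
loss = criterion(logits, target)  # the weights are applied inside the loss
loss.backward()
```

Ignoring unlabeled pixels leaves the prediction on background unconstrained, so an explicit fifth background channel would be the alternative design if the background has to be segmented as well.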

The IoU is usually not consistent:

  1. Using the original database, which was more imbalanced, the IoU reached 0.42 and then dropped to 0 within 1 epoch (by this I mean that, for example, in epoch 290 the IoU is 0.4 and then in epoch 291 it is 0).
  2. Using the “cleaned” database, both augmented and non-augmented, the IoU behaves similarly but with lower values (it reaches 0.3 before dropping to 0).
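
For context, the IoU in the script multiplies the argmax class indices (shape [N, H, W]) directly with the 4-channel masks, so the number it reports is not a per-class IoU. A sketch of a per-class version, assuming one-hot masks of shape [N, C, H, W], could look like this; `iou_counts` is an illustrative name, and the counts it returns can be accumulated over all batches of an epoch before dividing.

```python
import torch
import torch.nn.functional as F

def iou_counts(logits, masks):
    """Per-class intersection and union pixel counts for one batch.

    logits: [N, C, H, W] raw model outputs
    masks:  [N, C, H, W] one-hot ground truth (0/1)
    """
    num_classes = logits.shape[1]
    pred = logits.argmax(dim=1)                       # [N, H, W] class indices
    pred = F.one_hot(pred, num_classes)               # [N, H, W, C]
    pred = pred.permute(0, 3, 1, 2).bool()            # [N, C, H, W]
    truth = masks.bool()
    intersection = (pred & truth).sum(dim=(0, 2, 3))  # [C]
    union = (pred | truth).sum(dim=(0, 2, 3))         # [C]
    return intersection, union

# Toy example: the ground truth has one class per row, the prediction is all class 0
masks = torch.zeros(1, 4, 4, 4)
for c in range(4):
    masks[0, c, c, :] = 1.0
logits = torch.zeros(1, 4, 4, 4)
logits[:, 0] = 1.0

intersection, union = iou_counts(logits, masks)
# clamp avoids division by zero for classes absent from both prediction and truth
iou = intersection.float() / union.clamp(min=1).float()
print(iou)  # class 0: 4 / 16 = 0.25, other classes: 0
```

Because argmax always picks one of the 4 channels, unlabeled background pixels still count toward the union of whichever class is predicted there.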

As for inference, it is at a “mid-point”: it isn’t perfect, but it isn’t bad at all. It segments the background as the first class (I am not including background in the classes).

I’ve been thinking about using the Dice loss from the SMP repository, but I’m not sure how to do it, or whether I should use the class weights with it.
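
To make the idea concrete, here is a hand-written soft Dice loss in plain PyTorch. It is not SMP's implementation (SMP exposes `smp.losses.DiceLoss`, which takes a `mode` argument such as `"multiclass"` or `"multilabel"`); it is only a sketch under the assumption that the masks are [N, C, H, W] with 0/1 values and that a per-channel sigmoid is appropriate because background pixels have no channel set. The name `soft_dice_loss` and the optional per-class weighting are illustrative choices.

```python
import torch

def soft_dice_loss(logits, masks, class_weights=None, eps=1e-7):
    """Soft Dice loss averaged over classes.

    logits: [N, C, H, W] raw model outputs
    masks:  [N, C, H, W] ground truth with 0/1 values
    class_weights: optional [C] tensor for a weighted average over classes
    """
    probs = torch.sigmoid(logits)                     # independent per-channel probabilities
    dims = (0, 2, 3)
    intersection = (probs * masks).sum(dim=dims)      # [C]
    cardinality = (probs + masks).sum(dim=dims)       # [C]
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    loss_per_class = 1.0 - dice
    if class_weights is not None:
        return (loss_per_class * class_weights).sum() / class_weights.sum()
    return loss_per_class.mean()

# Toy masks: one class per row
masks = torch.zeros(1, 4, 4, 4)
for c in range(4):
    masks[0, c, c, :] = 1.0

confident_right = (masks * 2 - 1) * 20.0   # large positive logits where mask == 1
confident_wrong = -confident_right         # the exact opposite prediction
weights = torch.tensor([0.1981, 1.0000, 0.8752, 0.8618])

loss_right = soft_dice_loss(confident_right, masks)
loss_wrong = soft_dice_loss(confident_wrong, masks, class_weights=weights)
print(loss_right.item(), loss_wrong.item())  # close to 0 and close to 1
```

Since Dice is already computed per class, it is less sensitive to pixel imbalance than cross-entropy, so the class weights are optional here rather than required.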

  1. Do you have any tips to fix and improve the IoU and the inference?