UNet Model not learning for multiclass image segmentation task

Problem: The UNet model is not learning, even after 100 epochs. The training and validation losses stay the same.

Background information:
Each pixel of the image has exactly one label (4 classes in total: 0, 1, 2, or 3).
From the DataSet:
→ label shape from OCTDataset: (496, 512)
→ img shape from OCTDataset: (496, 512)
→ image dtype from OCTDataset: float32, label dtype from OCTDataset: float32

→ shape just after transform of the label: torch.Size([1, 512, 512])
→ shape just after transform of the img: torch.Size([1, 512, 512])
→ image dtype after transform: torch.float32, label dtype after transform: torch.float32

My UNet Model:

#####  U-Net Model
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
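        # Assumed: BatchNorm + ReLU follow each conv (bias=False implies a norm layer comes next)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
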
    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(
            self, in_channels=1, out_channels=3, features=[64, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
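        for feature in reversed(features):
            # Upsample by 2x, then refine with a DoubleConv after the skip concat
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))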

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
        # softmax
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
      skip_connections = []

      for down in self.downs:
        x = down(x)
        # keep the pre-pooling feature map for the matching decoder stage
        skip_connections.append(x)
        x = self.pool(x)

      x = self.bottleneck(x)
      skip_connections = skip_connections[::-1]

      for idx in range(0, len(self.ups), 2):
        x = self.ups[idx](x)
        skip_connection = skip_connections[idx//2]

        if x.shape != skip_connection.shape:
          x = TF.resize(x, size=skip_connection.shape[2:])

        concat_skip = torch.cat((skip_connection, x), dim=1)
        x = self.ups[idx+1](concat_skip)

      return self.softmax(self.final_conv(x))

My code for training:

# Training function
def train(model, loader, optimizer, loss_fn, device):
    epoch_loss = 0.0

    model.train()
    for x, y in loader:
      x = x.to(device, dtype=torch.float32)
      y = y.to(device, dtype=torch.float32)

      optimizer.zero_grad()
      y_pred = model(x)
      # removing channel from orig label
      y = y.squeeze(1)

      loss = loss_fn(y_pred, y.type(torch.LongTensor).cuda())
      # backpropagate and update the weights
      loss.backward()
      optimizer.step()
      epoch_loss += loss.item()* x.size(0)
    epoch_loss = epoch_loss/len(loader)
    return epoch_loss

# Evaluation function
def evaluate(model, loader, loss_fn, device):
    epoch_loss = 0.0

    model.eval()

    with torch.no_grad():
      for x, y in loader:
        x = x.to(device, dtype=torch.float32)
        y = y.to(device, dtype=torch.float32)

        y_pred = model(x)

        # removing channel from orig label
        y = y.squeeze(1)

        loss = loss_fn(y_pred, y.type(torch.LongTensor).cuda())
        epoch_loss += loss.item()* x.size(0)     

      epoch_loss = epoch_loss/len(loader)

    return epoch_loss

import time

""" Hyperparameters """
H = 512
W = 512
size = (H, W)
num_epochs = 100
lr = 1e-4 # 0.001
device = torch.device('cuda')   ## 

""" Calculate the time taken """
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

class_weight = [0.9851726793996959,  0.005821095764414637,  0.005777171078234721,  0.0032290537576547757] # class weights of whole dataset (before making subset of training, validation, and test dataset)
class_weight_tensor = torch.Tensor(class_weight).to(device, dtype=torch.float32)

model = UNET(in_channels=1, out_channels=4)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
# loss function 
loss_fn = nn.CrossEntropyLoss(weight=class_weight_tensor)

# lists to collect data while training
best_valid_loss = float("inf")
train_losses_list = list()
val_losses_list = list()

# if you want to resume training from check point then change it to True
resume_Training = False

Training loop

""" Training the model """
for epoch in range(num_epochs):
  # resume
  if resume_Training:
    checkpoint = torch.load('/content/drive/MyDrive/Practical_work/May2022/outputs/model_V2.pth')
    epoch = checkpoint['epoch']
    train_losses_list = checkpoint['loss_train_list']
    val_losses_list = checkpoint['loss_val_list']
    best_valid_loss = val_losses_list[-1]
    resume_Training = False

  start_time = time.time()

  train_loss = train(model, oct_trainingloader, optimizer, loss_fn, device = device)
  valid_loss = evaluate(model, oct_validationloader, loss_fn, device = device) # epoch_loss, precision, recall, thresholds, iou_val

  # append
  train_losses_list.append(train_loss)
  val_losses_list.append(valid_loss)

  """ Saving the model """
  if valid_loss < best_valid_loss:
    data_str = f"Valid loss improved from {best_valid_loss:2.4f} to {valid_loss:2.4f}. Saving checkpoint"
    print(data_str)

    best_valid_loss = valid_loss
    state = {'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss_train_list': train_losses_list,
            'loss_val_list': val_losses_list,
            }
    torch.save(state, '/content/drive/MyDrive/Practical_work/May2022/outputs/model_V2.pth')

  end_time = time.time()
  epoch_mins, epoch_secs = epoch_time(start_time, end_time)

  data_str = f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s'
  data_str += f'\tTrain Loss: {train_loss:.3f}'
  data_str += f'\t Val. Loss: {valid_loss:.3f}'
  print(data_str)


# Plotting losses

nn.CrossEntropyLoss expects raw logits as the model outputs, not probabilities, so remove the nn.Softmax from your model and rerun the training.
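To illustrate the point, here is a minimal sketch on hypothetical random data (the tensor sizes and the seed are made up for the example). It compares passing raw logits with passing softmax outputs to nn.CrossEntropyLoss. Because the loss applies log_softmax internally, feeding probabilities squeezes the inputs into [0, 1], and for four classes the loss can then never drop below log(1 + 3/e), roughly 0.74, no matter how good the predictions are.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny batch: 2 images, 4 classes, 8x8 pixels.
logits = torch.randn(2, 4, 8, 8) * 5.0
target = torch.randint(0, 4, (2, 8, 8))  # class indices, dtype int64

loss_fn = nn.CrossEntropyLoss()

# Intended usage: raw logits in, log_softmax is applied inside the loss.
loss_from_logits = loss_fn(logits, target)

# Problematic usage: softmax here plus log_softmax inside the loss.
probs = torch.softmax(logits, dim=1)
loss_from_probs = loss_fn(probs, target)

# The intended path is equivalent to log_softmax followed by NLLLoss.
manual = nn.NLLLoss()(torch.log_softmax(logits, dim=1), target)

print("loss from logits:       ", loss_from_logits.item())
print("log_softmax + NLLLoss:  ", manual.item())
print("loss from probabilities:", loss_from_probs.item())
```

The gradients in the second case are also much flatter, which would be consistent with a loss that barely moves during training.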

I tried again after removing the softmax from the UNet model, but the training still does not look right. I ran 30 epochs as a cross-check, and in the second epoch the loss dropped from 36% to 5%, which seems suspicious.


The labels are highly imbalanced, so I used class weights in the CE loss. I calculated the weights as follows:
Classes → 0 is the background, and the remaining three (1, 2, and 3) are the class indices in the labels.

First I took all labels from the whole dataset (before splitting it into training, validation, and test sets) and counted how many pixels belong to each class. From these counts, I calculated the weights.

class_counts = [294219409,   1738456,   1725338,    964349]
total_count = sum(class_counts)
class_weight = [class_counts[0]/total_count, class_counts[1]/total_count, class_counts[2]/total_count, class_counts[3]/total_count ]

Then I passed these weights to the CE loss:

loss_fn = nn.CrossEntropyLoss(weight=class_weight_tensor)

Is this the correct way to calculate class weights?

Currently, I am only using the CE loss with class weights for training. I know Dice loss or Focal loss might be a better fit, but even with just the CE loss the network should learn something (even if it performs worse than with Dice loss).

Do you see any mistakes in my process?

Thank you very much for your support!

Best Regards,

The loss is not given in a percentage and an initial drop is common, so could you explain your concern a bit?

No, since you are now weighting the majority class higher and are thus adding more penalty to a wrong classification of class 0. Divide by the class frequencies as explained in this post.
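A minimal sketch of inverse-frequency weighting, using the class counts posted above. The linked post is not reproduced here, so the exact scheme (in particular the normalization so the weights sum to 1) is an assumption and only one of several common choices.

```python
import torch
import torch.nn as nn

# Pixel counts per class, taken from the question above.
class_counts = torch.tensor(
    [294219409, 1738456, 1725338, 964349], dtype=torch.float64
)

# Relative frequency of each class.
freq = class_counts / class_counts.sum()

# Inverse frequency: rare classes get large weights, the background a small one.
inv_freq = 1.0 / freq

# Optional normalization so the weights sum to 1 (an assumption, not required).
class_weight = (inv_freq / inv_freq.sum()).to(torch.float32)

loss_fn = nn.CrossEntropyLoss(weight=class_weight)
print(class_weight)
```

In the training script, the weight tensor would still need to be moved to the same device as the model outputs, as the original code already does with class_weight_tensor.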

I meant a drop from 0.36 to 0.05!

Thanks for the link. It’s very helpful. Based on this link, I have calculated weights per batch for the cross-entropy loss. Still, even after training for 100 epochs, learning is quite slow.


I have seen many code repos that combine cross-entropy loss with Dice loss, but the contribution of each term to the total loss varies (some simply add the two losses, others take 50% of each). Are there any useful guidelines for designing the loss function, especially for an imbalanced dataset?
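For reference, a combined loss of this kind can be sketched as below. The class name CEDiceLoss and the parameters alpha and eps are hypothetical names chosen for this example, not from any particular repo. The mixing factor alpha is a hyperparameter to tune on the validation set; the sketch does not claim any value is best.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CEDiceLoss(nn.Module):
    """Hypothetical helper: alpha * cross-entropy + (1 - alpha) * soft Dice loss."""

    def __init__(self, ce_weight=None, alpha=0.5, eps=1e-6):
        super().__init__()
        self.ce = nn.CrossEntropyLoss(weight=ce_weight)
        self.alpha = alpha  # share of the CE term in the total loss
        self.eps = eps      # avoids division by zero for absent classes

    def forward(self, logits, target):
        # logits: (N, C, H, W) raw scores, target: (N, H, W) int64 class indices
        ce = self.ce(logits, target)

        num_classes = logits.shape[1]
        probs = torch.softmax(logits, dim=1)
        one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).to(probs.dtype)

        # Soft Dice per class, accumulated over batch and spatial dimensions.
        dims = (0, 2, 3)
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice_per_class = (2 * intersection + self.eps) / (cardinality + self.eps)
        dice_loss = 1.0 - dice_per_class.mean()

        return self.alpha * ce + (1.0 - self.alpha) * dice_loss


# Usage on hypothetical random data: 2 images, 4 classes, 8x8 pixels.
torch.manual_seed(0)
logits = torch.randn(2, 4, 8, 8, requires_grad=True)
target = torch.randint(0, 4, (2, 8, 8))
loss = CEDiceLoss(alpha=0.5)(logits, target)
loss.backward()
print(loss.item())
```

Averaging the Dice score per class gives every class the same influence regardless of its pixel count, which is the usual motivation for adding a Dice term on imbalanced data.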

Hello! I'm having the same issue. Did you solve this problem, ankita?