Validation loss is large for multitask loss (Unet based)

Train loss is good but validation loss is large and not decreasing much. My model looks as follows

class DownSample_Block(nn.Module):
  def __init__(self,in_ch,out_ch,down_sample=True,batch_norm=True):
    super().__init__()
    self.down_sample = down_sample
    self.batch_norm = batch_norm
    self.pool = nn.MaxPool2d(2)
    self.bn = nn.BatchNorm2d(num_features = out_ch)
    self.conv1 = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1)


  def forward(self,x):
    if self.down_sample:
      x = self.pool(x)

    # conv ->bn-> relu
    x = self.conv1(x)
    if self.batch_norm:
      x = self.bn(x)
    x = F.relu(x)

    x = self.conv2(x)
    if self.batch_norm:
      x = self.bn(x)
    x = F.relu(x)

    return x



class Upsample_Block(nn.Module):
  def __init__(self,in_ch,out_ch,skip_ch):
    super().__init__()
    self.upsample = nn.ConvTranspose2d(in_ch,in_ch,kernel_size=2,stride=2)
    # self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
    self.bn = nn.BatchNorm2d(num_features=out_ch)
    self.conv1 = nn.Conv2d(in_channels=in_ch+skip_ch, out_channels=out_ch, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1)
  
   

  def forward(self,x,skip):
    up = self.upsample(x)
    out = torch.cat((skip,up),axis=1)
    out = F.relu(self.bn(self.conv1(out)))
    out = F.relu(self.bn(self.conv2(out)))
    return out


class Unet(nn.Module):
    def __init__(self):
      super().__init__()
      bins = 15
      self.eblock1 = DownSample_Block(3,32,down_sample=False)
      self.eblock2 = DownSample_Block(32,64)
      self.eblock3 = DownSample_Block(64,128)
      self.eblock4 = DownSample_Block(128,256)
      self.eblock5 = DownSample_Block(256,512)

      self.dblock2 = Upsample_Block(512,256,256)
      self.dblock3 = Upsample_Block(256,128,128)
      self.dblock4 = Upsample_Block(128,64,64)
      self.dblock5 = Upsample_Block(64,32,32)
     
      # distance
      self.dist_conv = nn.Conv2d(32,bins,kernel_size=1)

      #segmentation
      self.seg_conv = nn.Conv2d(32+bins,2,kernel_size=1)
      

    def forward(self,x):
      d1 = self.eblock1(x)
      # print(d1.shape)
      d2 = self.eblock2(d1)
      # print(d2.shape)
      d3 = self.eblock3(d2)
      # print(d3.shape)
      d4 = self.eblock4(d3)
      # print(d4.shape)
      d5 = self.eblock5(d4)
      # print(d5.shape)

      u1 = self.dblock2(d5,d4)
      # print(u1.shape)
      u2 = self.dblock3(u1,d3) #d5
      # print(u2.shape)
      u3 = self.dblock4(u2,d2)
      # print(u3.shape)
      u4 = self.dblock5(u3,d1)
      # print(u4.shape)

      #distance out
      out_dist = self.dist_conv(u4)

      #segment out  # relu(out_dist) + u4 out
      out_relu = F.relu(out_dist)  
      out_concat = torch.cat([u4,out_relu],axis=1)
      out_seg = self.seg_conv(out_concat)


      # print(out_dist.shape,out_seg.shape)
      return out_seg,out_dist


My loss function and distance transform function is

def distance_map_batch_v2(Y_batch,threshold=20,bins=15):
    '''
    Computes distance map of https://arxiv.org/abs/1709.05932 .
    Threshold and number of bins are parameters explained in the paper.
    Y_batch is the one hot encoded mask of the map with buildings and background. Mask for background is stored on channel 0 and Mask for buildings is stored on channel 1. Y_batch is a torch tensor
    returns: torch tensor distance map
    
    '''
 
    Y_batch_dist = []
    for i in range(len(Y_batch)):
        msk = Y_batch[i].detach().cpu().numpy()
        msk = msk.transpose(1,2,0)
        distance_build=distance_transform_bf(msk[:,:,1],sampling=2)
        distance_background=distance_transform_bf(msk[:,:,0],sampling=2)
        distance_build=np.minimum(distance_build,threshold*(distance_build>0))
        distance_background=np.minimum(distance_background,threshold*(distance_background>0))
        distance=(distance_build-distance_background)
        distance=(distance-np.amin(distance))/(np.amax(distance)-np.amin(distance)+1e-50)*(bins-1)
        inp=torch.LongTensor(distance)
        inp_ = torch.unsqueeze(inp, len(distance.shape))
        one_hot = torch.FloatTensor(distance.shape[0], distance.shape[1], bins).zero_()
        one_hot.scatter_(len(distance.shape), inp_, 1)
        one_hot=np.asarray(one_hot)
        one_hot = one_hot.transpose(2,0,1)
        Y_batch_dist.append(one_hot)
      

    return torch.FloatTensor(np.asarray(Y_batch_dist))


def cce(y_true,y_est):
    '''
    Loss used --> cross entropy
    '''
    y_true = y_true.contiguous().view(-1,y_true.size()[-1])
    y_est = y_est.contiguous().view(-1,y_true.size()[-1])
    y_true_flat = y_true.max(-1)[1]
    # print(y_est.shape,y_true_flat.shape)
    loss_func = nn.CrossEntropyLoss()
    loss = loss_func(y_est,y_true_flat)#be careful inverse order of arguments 
    return loss

class MulticlassLoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(MulticlassLoss, self).__init__()

    def forward(self, pred_mask,pred_dist, y_batch, y_dist, smooth=1):

        #comment out if your model contains a sigmoid or equivalent activation layer

        bce_weight = 0.5
        
        loss_seg = cce(y_batch,pred_mask)
        # print(loss_seg)
        loss_dist = cce(y_dist,pred_dist)
        # print(loss_dist)
        loss = (bce_weight*loss_seg)+(bce_weight*loss_dist)
  
        return loss

Training code is

checkpoint_path = 'model2/chkpoint_'
best_model_path = 'model2/bestmodel.pt'

# model = model.to(device)
epochs = 50
criterion = MulticlassLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
dice_coef = DiceCoef()
valid_loss_min = 3.95275  # some random value


train_loss,val_loss = [],[]
train_iou,val_iou = [],[]
train_coef,val_coef = [],[]

for epoch in range(epochs):
    print('Epoch {}/{}'.format(epoch + 1, epochs))
    start_time = time.time()
    
    model.train()
    running_train_loss = []
    running_train_score = []
    running_train_coef = []
    for image,mask in train_loader: 
        image = image.to(device,dtype=torch.float)
        mask = mask.to(device,dtype=torch.float)
        true_dist = distance_map_batch_v2(mask)
        true_dist = true_dist.to(device,dtype=torch.float)
        pred_mask,pred_dist = model.forward(image) # forward propogation
        loss = criterion(pred_mask, pred_dist, mask, true_dist) 
        # print(pred_mask.shape,mask.shape)
        score = iou_batch(pred_mask,mask)
        coef = dice_coef(pred_mask,mask)
        optimizer.zero_grad() # setting gradient to zero
        loss.backward()
        optimizer.step()
        running_train_loss.append(loss.item())
        running_train_score.append(score)
        running_train_coef.append(coef.item())

    
    running_val_loss = []
    running_val_score = []
    running_val_coef = []
    model.eval()
    with torch.no_grad():
        for image,mask in val_loader:
            image = image.to(device,dtype=torch.float)
            mask = mask.to(device,dtype=torch.float)  
            true_dist = distance_map_batch_v2(mask)
            true_dist = true_dist.to(device,dtype=torch.float)                          
            pred_mask,pred_dist = model.forward(image)
            loss = criterion(pred_mask, pred_dist, mask, true_dist)
            score = iou_batch(pred_mask,mask)
            coef = dice_coef(pred_mask,mask)
            running_val_loss.append(loss.item())
            running_val_score.append(score)
            running_val_coef.append(coef.item())
    
    epoch_train_loss,epoch_train_score,epoch_train_coef = np.mean(running_train_loss) ,np.mean(running_train_score),np.mean(running_train_coef)
    print('Train loss : {} iou : {}, dice coef: {}'.format(epoch_train_loss,epoch_train_score,epoch_train_coef))                       
    train_loss.append(epoch_train_loss)
    train_iou.append(epoch_train_score)
    train_coef.append(epoch_train_coef)
    
    epoch_val_loss,epoch_val_score,epoch_val_coef = np.mean(running_val_loss),np.mean(running_val_score),np.mean(running_val_coef)
    print('Validation loss : {} iou : {}, dice coef: {}'.format(epoch_val_loss,epoch_val_score,epoch_val_coef))                                
    val_loss.append(epoch_val_loss)
    val_iou.append(epoch_val_score)
    val_coef.append(epoch_val_coef)


    # create checkpoint variable and add important data
    checkpoint = {
            'epoch': epoch + 1,
            'valid_loss_min': epoch_val_loss,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        
    # save checkpoint
    save_ckp(checkpoint, False, checkpoint_path, best_model_path)
    ## TODO: save the model if validation loss has decreased
    if epoch_val_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,epoch_val_loss))
            # save checkpoint as best model
            save_ckp(checkpoint, True, checkpoint_path, best_model_path)
            valid_loss_min = epoch_val_loss
    
    time_elapsed = time.time() - start_time
    print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

Results is

Epoch 1/50
Train loss : 1.636072209051677 iou : 0.2590805115404147, dice coef: 0.4123131136809077
Validation loss : 3.644066361586253 iou : 0.3961057460856792, dice coef: 0.5682796835899353
Validation loss decreased (3.952750 --> 3.644066). Saving model …
32m 49s
Epoch 2/50
Train loss : 1.2269017398357391 iou : 0.24557374558415218, dice coef: 0.39523599807705195
Validation loss : 3.8924281001091003 iou : 0.4135372376490164, dice coef: 0.5854614973068237
14m 48s
Epoch 3/50
Train loss : 1.1182671712977545 iou : 0.23921578082771644, dice coef: 0.3875904355730329
Validation loss : 3.321150771776835 iou : 0.39828164579205444, dice coef: 0.5704469372828801
Validation loss decreased (3.644066 --> 3.321151). Saving model …
14m 55s
Epoch 4/50
Train loss : 1.047594209228243 iou : 0.2370423747555327, dice coef: 0.3851519699607577
Validation loss : 3.4572733004887897 iou : 0.413889831623608, dice coef: 0.5856322149435679
14m 38s
Epoch 5/50
Train loss : 0.9927580939871924 iou : 0.23708825814195278, dice coef: 0.38523539411170143
Validation loss : 3.2639930764834086 iou : 0.3895362960387864, dice coef: 0.560992560784022
Validation loss decreased (3.321151 --> 3.263993). Saving model …
14m 32s
Epoch 6/50
Train loss : 0.9615498576845442 iou : 0.23610517876921325, dice coef: 0.38409997075796126
Validation loss : 3.03681640625 iou : 0.3960668315457316, dice coef: 0.5678311238686243
Validation loss decreased (3.263993 --> 3.036816). Saving model …
14m 28s
Epoch 7/50
Train loss : 0.9389585907970156 iou : 0.23596643507862994, dice coef: 0.3839401804975101
Validation loss : 2.9654738306999207 iou : 0.38705827275984567, dice coef: 0.5584435482819875
Validation loss decreased (3.036816 --> 2.965474). Saving model …
14m 22s
Epoch 8/50
Train loss : 0.9142322753156934 iou : 0.2337625844423078, dice coef: 0.38102387764624185
Validation loss : 2.775194569428762 iou : 0.37924780132712754, dice coef: 0.5504745692014694
Validation loss decreased (2.965474 --> 2.775195). Saving model …
14m 12s
Epoch 9/50
Train loss : 0.8896246692964009 iou : 0.22980133446432147, dice coef: 0.3760204364146505
Validation loss : 2.7394471764564514 iou : 0.3958687851408545, dice coef: 0.5675292551517487
Validation loss decreased (2.775195 --> 2.739447). Saving model …
14m 23s
Epoch 10/50
Train loss : 0.8742967473609107 iou : 0.22970163538591346, dice coef: 0.375763281754085
Validation loss : 2.7312774141629537 iou : 0.43684185002205045, dice coef: 0.6085633635520935
Validation loss decreased (2.739447 --> 2.731277). Saving model …
14m 32s
Epoch 11/50
Train loss : 0.8534701717751366 iou : 0.2294650052856951, dice coef: 0.375523764533656
Validation loss : 2.4273023923238117 iou : 0.3948336292033707, dice coef: 0.5670865674813589
Validation loss decreased (2.731277 --> 2.427302). Saving model …
14m 28s
Epoch 12/50
Train loss : 0.8340641353811536 iou : 0.22669237089401914, dice coef: 0.372151409302439
Validation loss : 2.761715757846832 iou : 0.3619836216192266, dice coef: 0.5321707213918369
14m 46s
Epoch 13/50
Train loss : 0.8210224441119602 iou : 0.2280403249532117, dice coef: 0.37345540864127025
Validation loss : 2.6830695708592733 iou : 0.4035514459610498, dice coef: 0.575622484087944
14m 59s
Epoch 14/50
Train loss : 0.8059026994875499 iou : 0.2248244423588186, dice coef: 0.3693577834538051
Validation loss : 2.1368691086769105 iou : 0.412322404646021, dice coef: 0.5850987017154694
Validation loss decreased (2.427302 --> 2.136869). Saving model …
15m 0s
Epoch 15/50
Train loss : 0.7877428701945713 iou : 0.22854895431555933, dice coef: 0.3743239802973611
Validation loss : 2.2300844828287762 iou : 0.38135859358075624, dice coef: 0.5536836306254069
15m 10s
Epoch 16/50
Train loss : 0.7829681926539966 iou : 0.22874669823879584, dice coef: 0.3749655529856682
Validation loss : 2.4254708528518676 iou : 0.41929319700766815, dice coef: 0.5917253951231639
15m 20s
Epoch 17/50
Train loss : 0.7644439024584634 iou : 0.23029824865898385, dice coef: 0.3765100053378514
Validation loss : 2.4135544896125793 iou : 0.4461470567967697, dice coef: 0.6174438158671062
15m 24s
Epoch 18/50
Train loss : 0.7528764850326947 iou : 0.22475351346286607, dice coef: 0.36964563727378846
Validation loss : 2.076960837841034 iou : 0.46038506970180504, dice coef: 0.6315884411334991
Validation loss decreased (2.136869 --> 2.076961). Saving model …
15m 25s
Epoch 19/50
Train loss : 0.7382002962487084 iou : 0.22603396174056087, dice coef: 0.3714781603642872
Validation loss : 2.2174523651599882 iou : 0.45246894388506703, dice coef: 0.623819363117218
14m 40s
Epoch 20/50
Train loss : 0.726068066912038 iou : 0.22376600579043568, dice coef: 0.3686292886734009
Validation loss : 1.9367382168769836 iou : 0.39958022276006816, dice coef: 0.5725013802448908
Validation loss decreased (2.076961 --> 1.936738). Saving model …
13m 45s
Epoch 21/50
Train loss : 0.714294318641935 iou : 0.22446068206685127, dice coef: 0.36995240322181155
Validation loss : 2.322433396180471 iou : 0.4313971246092888, dice coef: 0.6031914939483006
13m 46s
Epoch 22/50
Train loss : 0.6834765095795904 iou : 0.21772951261682544, dice coef: 0.3606676868030003
Validation loss : 2.0876523753007254 iou : 0.45842205361599536, dice coef: 0.6297118236621221
13m 40s
Epoch 23/50
Train loss : 0.6793745417680059 iou : 0.21839416961569164, dice coef: 0.36152691223791666
Validation loss : 2.267058895031611 iou : 0.47621269970422203, dice coef: 0.6458630830049514
13m 40s

Is my model overfitting. What might be reason here? Please someone help to solve this issue.

idea from paper

@ptrblck please help

Currently, I am experiencing the same issue. Have you been able to able fix it? If yes, your approach will be appreciated.