Implementation of multi-class dice loss with ignore_index=0

Hi everyone, I am trying to implement multi-class dice loss but I want to ignore a particular class with index=0, The below code runs without exception but the MIOU=0.0072 is always constant after 40 epochs.
what am I missing guys:

def one_hot_embedding(self, labels, num_classes):                                                                   
                                                                                                 
      y = torch.eye(num_classes)                                                                                      
      return y[labels].permute((0, 3, 1, 2)).cuda()                                                                   
                                                                                                                      
  def dice_loss_llJ(self, logits, true):                                                                              
      """Computes the Sørensen–Dice loss.                                                                             
      Note that PyTorch optimizers minimize a loss. In this                                                           
      case, we would like to maximize the dice loss so we                                                             
      return the negated dice loss.                                                                                   
      Args:                                                                                                           
          true: a tensor of shape [B, 1, H, W].                                                                       
          logits: a tensor of shape [B, C, H, W]. Corresponds to                                                      
              the raw output or logits of the model.                                                                  
          eps: added to the denominator for numerical stability.                                                      
      Returns:                                                                                                        
          dice_loss: the Sørensen–Dice loss.                                                                          
      """                                                                                                             
      eps = 1e-7                                                                                                      
      if self.ignore_index is not None:                                                                               
          mask = true != self.ignore_index                                                                            
                                                                                                                      
          logits = logits * mask.unsqueeze(1)                                                                         
          true_1_hot = self.one_hot_embedding((true * mask), self.nclasses) * mask.unsqueeze(1)  # N,H*W -> N,H*W, C  
                                                                                                                      
      else:                                                                                                           
          true_1_hot = self.one_hot_embedding(true, self.nclasses)                                                    
                                                                                                                      
      probas = F.softmax(logits, dim=1)                                                                               
                                                                                                                      
      true_1_hot = true_1_hot.type(logits.type())                                                                     
      dims = (0,) + tuple(range(2, true.ndimension()))                                                                
      intersection = torch.sum(probas * true_1_hot, dims)                                                             
      cardinality = torch.sum(probas + true_1_hot, dims)                                                              
                                                                                                                      
      dice_loss = (2. * intersection / (cardinality + eps)).mean()                                                    
      return 1 - dice_loss