Loss increase during training

Hello everyone. I am trying to use a Mean Teacher architecture with two U-Nets in a semi-supervised setup, but I am seeing strange behaviour of the loss during training. I have a supervised loss and a consistency loss computed on the unlabeled data, which I then sum to obtain the loss used for backpropagation.


This is an example of my training. It looks like the consistency loss contribution is not backpropagated.

I checked whether each loss term has requires_grad = True, and all of them do.
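
To make sure of it, this is roughly the kind of check I run: looking at grad_fn on each term (which tells whether the tensor is still attached to the graph) and backpropagating the consistency term alone to see whether any gradient actually reaches the student model. This is only a sketch reusing the variable names from the trainer below, not my exact debugging code.

    # Sanity-check sketch (same variable names as in the trainer below, run inside the batch loop)
    print(supervised_loss.requires_grad, supervised_loss.grad_fn)    # expect: True, <...Backward0>
    print(consistency_loss.requires_grad, consistency_loss.grad_fn)  # grad_fn is None if the term is cut off from the graph

    # backpropagate only the consistency term and look at the gradient magnitude on the student
    optimizer.zero_grad()
    consistency_loss.backward(retain_graph=True)
    grad_norm = sum(p.grad.norm().item() for p in model.parameters() if p.grad is not None)
    print("grad norm from consistency loss alone:", grad_norm)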

This is my trainer function:

    def _train(self,writer):
        self.model.train()  # train mode
        self.ema_model.train()  # train mode
        cons = torch.as_tensor(0, dtype=torch.float32, device=device)
        cons_weight = torch.as_tensor(0, dtype=torch.float32, device=device)

        temp = []  # accumulate the losses here
        n_batch = 0
        a = 0.0  # running sum of total loss over the epoch
        b = 0.0  # running sum of supervised loss over the epoch

        batch_iter = tqdm(enumerate(self.training_DataLoader), 'Training', total=len(self.training_DataLoader), leave=False)
        for i_batch, (x, y) in batch_iter:
            n_batch += 1

            volume_batch, label_batch = x, y[0]
            volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()
            unlabeled_volume_batch = volume_batch[self.labeled_bs:]

            noise = torch.clamp(torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
            ema_inputs = unlabeled_volume_batch + noise
            outputs = model(volume_batch)
            with torch.no_grad():
                ema_output = ema_model(ema_inputs)
            T = 8
            volume_batch_r = unlabeled_volume_batch.repeat(2, 1, 1, 1)
            stride = volume_batch_r.shape[0] // 2
            preds = torch.zeros([stride * T, 1, 384, 384]).cuda()
            for i in range(T//2):
                ema_inputs = volume_batch_r + torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)
                with torch.no_grad():
                    preds[2 * stride * i:2 * stride * (i + 1)] = ema_model(ema_inputs)
            preds = F.softmax(preds, dim=1)
            preds = preds.reshape(T, stride, 1, 384, 384)
            preds = torch.mean(preds, dim=0)  # (stride, 1, 384, 384)
            # uncertainty = -1.0*torch.sum(preds*torch.log(preds + 1e-6), dim=1, keepdim=True) # (stride, 1, 384, 384)
            uncertainty = -1.0*torch.sum(label_batch[self.labeled_bs:]*torch.log(preds + 1e-6), dim=1, keepdim=True) # (stride, 1, 384, 384)

            weights = F.softmax(1 - uncertainty, dim=0)         # weight map = softmax(confidence map)
            ema_probs = torch.sum(preds * weights, dim=0)   # integration result
            ema_seg_uncertainty = -1.0 * torch.sum(ema_probs * torch.log2(ema_probs + 1e-6), dim=1, keepdim=True)   # U_seg

            ## calculate the loss
            supervised_loss = criterion(outputs[:self.labeled_bs], label_batch[:self.labeled_bs]) 


            # consistency_weight = get_current_consistency_weight(iter_num//150)
            consistency_weight = get_current_consistency_weight(self.epoch)
            # consistency_weight = 0.1 * math.exp(-5 * math.pow((1 - self.epoch / self.epochs), 2))
            # consistency_weight = args.consistency * math.exp(-5 * math.pow((1 - self.epoch / self.epochs), 2))
            # consistency_dist = consistency_criterion(outputs[labeled_bs:], ema_output) #(batch, 2, 112,112,80)
            consistency_dist = torch.pow(outputs[self.labeled_bs:] - ema_output, 2) # (unlabeled_bs, 1, 384, 384)
            # threshold = (0.75+0.25*ramps.sigmoid_rampup(self.iter_num, len(self.training_DataLoader)*self.epochs))*np.log(2)
            # mask = (uncertainty<threshold).float()
            # consistency_dist = torch.sum(mask*consistency_dist)/(2*torch.sum(mask)+1e-16)
            consistency_dist = consistency_dist * (1 - ema_seg_uncertainty)
            consistency_dist = torch.mean(consistency_dist)
            consistency_loss = consistency_weight * consistency_dist

            cons += consistency_loss.detach()  # detach: this running sum is only used for logging
            cons_weight += consistency_weight

            a += supervised_loss.item() + consistency_loss.item()
            b += supervised_loss.item()

            loss = supervised_loss + consistency_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            update_ema_variables(model, ema_model, args.ema_decay, self.iter_num)

            self.iter_num = self.iter_num + 1

            logging.info('Epoch %d | iteration %d : loss : %f, sup_loss: %f, consistency_loss: %f, cons_dist: %f, loss_weight: %f' %
                         (self.epoch, self.iter_num, loss.item(),supervised_loss.item(), consistency_loss.item(), consistency_dist.item(), consistency_weight))
            
        
        self.report_losses.append(temp)
        print(" CONSISTENCY LOSS: ", round(cons.item() / n_batch, 5), cons_weight.item() / n_batch)
        losses = a / n_batch  # average total loss per batch
        print(" TRAIN LOSS: ", losses)
        self.task_losses.append([b / n_batch, round(cons.item() / n_batch, 5), cons_weight.item() / n_batch])
        self.training_loss.append(losses)
        self.consistency_loss.append(round(cons.item() / n_batch, 6))
        self.learning_rate.append(self.optimizer.param_groups[0]['lr'])
        torch.save(self.model.state_dict(), self.model_name)
        print("Model Saved")

and this is how I create the models:

    def create_model(ema=False):
        # Network definition
        model = UNet(1, 1)
        model = model.cuda()
        if ema:
            # the teacher (EMA) model is not trained by the optimizer;
            # its weights are only updated via update_ema_variables
            for param in model.parameters():
                param.detach_()
        return model

    model = create_model()
    ema_model = create_model(ema=True)
    criterion = MultiTaskLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
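
For completeness, get_current_consistency_weight and update_ema_variables are not shown above; they are based on the standard Mean Teacher implementation. The sketch below shows that standard form (the exact ramp-up length and base weight in my code may differ):

    import numpy as np

    # Standard Mean Teacher helpers (sketch; exact constants may differ from my code)
    def sigmoid_rampup(current, rampup_length):
        # exp(-5 * (1 - t)^2): ramps from 0 to 1 over rampup_length epochs
        if rampup_length == 0:
            return 1.0
        current = float(np.clip(current, 0.0, rampup_length))
        phase = 1.0 - current / rampup_length
        return float(np.exp(-5.0 * phase * phase))

    def get_current_consistency_weight(epoch, consistency=0.1, rampup_length=40):
        # consistency weight ramps up from 0 to `consistency` over the first rampup_length epochs
        return consistency * sigmoid_rampup(epoch, rampup_length)

    def update_ema_variables(model, ema_model, alpha, global_step):
        # teacher weights = exponential moving average of student weights
        alpha = min(1 - 1 / (global_step + 1), alpha)
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)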

Thanks in advance