Jointly training two deep neural networks

Hello, I have two deep neural networks: one for speech enhancement, which removes noise from the speech signals, and one for intent classification. I am trying to train them jointly, i.e. the output of the first network is passed directly to the second network, as follows:

def _train_epoch(self, epoch=100):
    """Train the enhancement model for one epoch, with a nested intent-model pass.

    For every batch of ``self.train_data_loader`` the enhancement model
    (``self.model``) takes one optimisation step on
    ``self.loss_function(clean, enhanced)``. After each of those steps the
    intent model (``model_back``) is trained on a complete pass over
    ``train_set_generator``; every 100 intent steps it is scored on at most
    102 batches of ``valid_set_generator``.

    Args:
        epoch: Value shown in the progress messages only; it does not
            control how many passes are executed.

    Side effects:
        Writes ``model_back.state_dict()`` to ``'bestmodel.pkl'`` and raises
        the module-level ``best_accuracy`` whenever validation accuracy
        exceeds it.
    """
    # Without this declaration the assignment below would make the name
    # local and the comparison would fail with UnboundLocalError.
    global best_accuracy

    for mixture, clean, _name in self.train_data_loader:
        mixture = mixture.to(self.device, dtype=torch.float)
        clean = clean.to(self.device, dtype=torch.float)

        # One optimisation step of the enhancement network.
        self.optimizer.zero_grad()
        enhanced = self.model(mixture).to(self.device)
        loss = self.loss_function(clean, enhanced)
        loss.backward()
        self.optimizer.step()

        # NOTE(review): this loop is nested inside the enhancement loop, so a
        # full pass over train_set_generator runs (and its step counter
        # restarts at 0) after *every* enhancement batch. The `enhanced`
        # tensor computed above is not fed to model_back; the classifier
        # input comes from train_set_generator instead. If one classifier
        # pass per epoch is intended, dedent this loop -- confirm intent.
        for step, (intent_input, intent_label) in enumerate(train_set_generator):
            # Re-enter training mode; the validation branch below switches
            # the model to eval mode.
            model_back.train()
            y = model_back(intent_input.float().to(device2))
            loss_back = loss_func_back(y[0], intent_label[0].to(device2))
            print("Iteration %d in epoch%d--> loss = %f" % (step, epoch, loss_back.item()), end='\r')
            loss_back.backward()
            optimizer_back.step()
            optimizer_back.zero_grad()

            if step % 100 != 0:
                continue

            # Periodic validation on a capped number of batches.
            model_back.eval()
            correct = []
            with torch.no_grad():  # no autograd graph is needed for scoring
                for j, (val_input, label) in enumerate(valid_set_generator):
                    y_eval = model_back(val_input.float().to(device2))
                    intent_pred = torch.argmax(y_eval[0].detach().cpu(), dim=1)
                    correct.append((intent_pred == label[0]).float())
                    if j > 100:
                        break
            intent_acc = np.mean(np.hstack(correct))

            iter_acc = '\n iteration %d epoch %d -->' % (step, epoch)
            print(iter_acc, intent_acc, best_accuracy)
            if intent_acc > best_accuracy:
                improved_accuracy = 'Current accuracy {}, {}'.format(intent_acc, best_accuracy)
                print(improved_accuracy)
                torch.save(model_back.state_dict(), 'bestmodel.pkl')
                # Record the new best so a later, worse result cannot
                # overwrite the checkpoint.
                best_accuracy = intent_acc

Here the loss of the first DNN is the MSE loss and the loss of the second DNN is the cross-entropy loss. The problem is that, after many iterations in the first epoch, the second DNN starts iterating over the first loop again instead of moving on to the second loop. Any help will be appreciated.