Training two models simultaneously: incorrect backpropagation

When training two networks for online knowledge distillation on CUDA (2 GPUs), backpropagation for the teacher model fails. The code runs without errors, but the teacher's loss barely changes (it stays within a range of about 0.0005, which is far too small). When the teacher is trained separately, its loss drops by a factor of roughly 0.8 every minibatch. The student's backward pass works fine.

Both models are trained with nn.DataParallel on the same device. Training them on separate CUDA devices produced the same problem.
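For context, a minimal sketch of how the two models are assumed to be set up (the real model classes and constructor arguments are not shown in the snippet below, so the toy modules and device choice here are placeholders):

    import torch
    import torch.nn as nn

    # Toy stand-ins for the real teacher/student networks (their definitions are not shown).
    teacher = nn.Sequential(nn.Linear(40, 1), nn.Sigmoid())
    student = nn.Sequential(nn.Linear(40, 1), nn.Sigmoid())

    device = torch.device("cuda:0")

    # Both models are wrapped in nn.DataParallel so each minibatch is split across the two GPUs,
    # and both live on the same primary device.
    teacher = nn.DataParallel(teacher).to(device)
    student = nn.DataParallel(student).to(device)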

    criterion = nn.BCELoss()
    optimizer_t = optim.AdamW(student.parameters(), lr=teacher.lr, weight_decay=teacher.wgt_decay)
    optimizer_s = optim.AdamW(student.parameters(), lr=student.lr, weight_decay=student.wgt_decay)

    # TRAINING #
    if train_network:
        teacher.train()
        student.train()
    else:
        student.eval()
        teacher.eval()

    for i, data in enumerate(trainloader, 0):
        iTot += 1
        inputs, labels = data[0].to(device), data[1].to(device)  # move the batch to the device
        # normalise with the precomputed mean (m) and variance (v) arrays
        inputs = (inputs - torch.from_numpy(m)[np.newaxis, np.newaxis, :].to(device)) / torch.sqrt(
            torch.from_numpy(v)[np.newaxis, np.newaxis, :].to(device))

        # bookkeeping: sequences, frames and hours of data seen so far
        sequences += inputs.shape[0]
        frames += inputs.shape[0] * inputs.shape[1]
        hours += (inputs.shape[0] * inputs.shape[1]) * (
                trainset.datasets[0].frShift / trainset.datasets[0].fs) / (60 * 60)

        # teacher training
        outputs_teacher = teacher(inputs)
        outputs_teacher = outputs_teacher[0]
        optimizer_t.zero_grad()
        teacher_loss = criterion(outputs_teacher, labels)
        teacher_loss.backward()
        optimizer_t.step()
        torch.cuda.empty_cache()

        # student training
        optimizer_s.zero_grad()
        outputs, transmitted_feat = student(inputs.detach())  # ap = after pooling != amplitude phase
        student_loss = criterion(outputs, labels.detach())

        loss2 = student_loss + teacher_loss
        loss2.backward()
        optimizer_s.step()
        torch.cuda.empty_cache()

Help will be much appreciated.

The problem was solved by applying inputs.detach() and labels.detach().
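For reference, a self-contained toy version of one training iteration with those detach calls in place (the tiny models, data, learning rate, and the separate per-model losses are assumptions for the sketch, not a copy of the loop above):

    import torch
    import torch.nn as nn
    import torch.optim as optim

    # Placeholder teacher/student models with sigmoid outputs so BCELoss applies.
    teacher = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())
    student = nn.Sequential(nn.Linear(10, 1), nn.Sigmoid())

    criterion = nn.BCELoss()
    optimizer_t = optim.AdamW(teacher.parameters(), lr=1e-3)
    optimizer_s = optim.AdamW(student.parameters(), lr=1e-3)

    inputs = torch.randn(8, 10)
    labels = torch.randint(0, 2, (8, 1)).float()

    # Teacher step: its computation graph is consumed by this backward call.
    optimizer_t.zero_grad()
    teacher_loss = criterion(teacher(inputs), labels)
    teacher_loss.backward()
    optimizer_t.step()

    # Student step: inputs and labels are detached so the student's backward pass
    # stays inside the student's own graph.
    optimizer_s.zero_grad()
    student_loss = criterion(student(inputs.detach()), labels.detach())
    student_loss.backward()
    optimizer_s.step()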