Model distillation with mixed-precision training

Hi,

I am doing model distillation with mixed-precision training, but I am getting strange results: they are very different from training without distillation, even when the distillation weight is set to zero. I wonder if I am using mixed-precision training incorrectly. The code is as follows:

with torch.cuda.amp.autocast(enabled=True):
    outputs_student = model_student(inputs, targets)
    with torch.no_grad():  # the teacher only provides targets, no gradients needed
        outputs_teacher = model_teacher(inputs, targets)
    loss_distillation = distill_loss(outputs_student, outputs_teacher)
    loss_student = some_loss(outputs_student)
    loss = loss_student + weight * loss_distillation
optimizer.zero_grad()
scaler.scale(loss).backward()  # backward on the scaled loss
scaler.step(optimizer)         # unscale the gradients and step the optimizer
scaler.update()                # update the scale factor for the next iteration

where scaler = torch.cuda.amp.GradScaler(enabled=True). As mentioned above, I tried setting weight to 0, but the result was still very different from training with loss_student only. Is there anything wrong with my use of mixed-precision training?
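For reference, the baseline I am comparing against is roughly the following (a minimal sketch using the same names as above; I am assuming here that the baseline runs without autocast/GradScaler and trains on loss_student only):

outputs_student = model_student(inputs, targets)
loss = some_loss(outputs_student)
optimizer.zero_grad()
loss.backward()
optimizer.step()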

Just a follow-up: I find that the gradients of some layer weights are NaN after scaler.scale(loss).backward(), although the weights look fine after scaler.step(optimizer). This seems to happen only with mixed-precision training. Is this normal?

Yes, this is expected: invalid (inf/NaN) gradients can occur when the scale factor is too large. scaler.step() detects them and skips the parameter update for that iteration, and scaler.update() then reduces the scale factor, as described in the docs.
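If you want to confirm that this is what is happening, one common sanity check (a minimal sketch reusing the loop above; the print is just illustrative) is to compare the scale before and after scaler.update(), since the scale is reduced whenever inf/NaN gradients were found and the step was skipped:

scaler.scale(loss).backward()
scale_before = scaler.get_scale()
scaler.step(optimizer)   # skips optimizer.step() if the unscaled gradients contain inf/NaN
scaler.update()          # reduces the scale after a skipped step
if scaler.get_scale() < scale_before:
    print("inf/NaN gradients detected: the optimizer step was skipped this iteration")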


Thanks for your reply. Just another question related to this one. In the code below:

outputs_student = model_student(inputs, targets)
with torch.no_grad():
    outputs_teacher = model_teacher(inputs, targets)
loss_distillation = distill_loss(outputs_student, outputs_teacher)

model_student is actually wrapped in torch.nn.parallel.DistributedDataParallel for multi-GPU training. Given that the teacher model is in eval mode and does not require gradients during training, is it OK not to wrap model_teacher with DDP?

Yes, I believe this should be fine, since you are only using the teacher to compute a static target for the loss. DDP won’t need to synchronize any gradients for it.
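For completeness, this is roughly the setup I would expect (a minimal sketch; device and local_rank are placeholders for however you set up your process group):

from torch.nn.parallel import DistributedDataParallel as DDP

# Only the student needs DDP, since only its gradients must be synchronized across ranks.
model_student = DDP(model_student.to(device), device_ids=[local_rank])

# The teacher just produces static targets: keep it as a plain module,
# in eval mode and with gradients disabled.
model_teacher = model_teacher.to(device).eval()
for p in model_teacher.parameters():
    p.requires_grad_(False)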
