Hi all,
I’m implementing knowledge distillation of a student ResNet-50 from a teacher ResNet-101 for a speaker-verification task.
The total loss I want is
```
L = α · CE(logits_student, y)                     # classification
  + (1 − α) · MSE(features_student[layer3,4],
                  features_teacher[layer3,4])     # feature distillation
```
Both networks use SyncBatchNorm and I train in distributed mixed precision (`autocast` + `GradScaler`).
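For context, the distributed/AMP setup is roughly the following (a simplified sketch; `num_speakers`, `local_rank`, and `loader` are placeholders and the real model construction differs):

```python
import torch
import torchvision
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler

# (sketch) student/teacher are built elsewhere; torchvision ResNets used as stand-ins here
student = torchvision.models.resnet50(num_classes=num_speakers)
teacher = torchvision.models.resnet101(num_classes=num_speakers)

student = torch.nn.SyncBatchNorm.convert_sync_batchnorm(student).cuda()
teacher = torch.nn.SyncBatchNorm.convert_sync_batchnorm(teacher).cuda().eval()
for p in teacher.parameters():
    p.requires_grad_(False)                     # teacher is frozen

student = DDP(student, device_ids=[local_rank])
scaler = GradScaler()

for x, y in loader:
    x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
    with autocast():
        ...                                     # forward passes + losses as in the snippets below
```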
1. What goes wrong if I do it “the naïve way”
```python
# ↓ one forward pass that produces both terms
logits, f3, f4 = student_forward_return_everything(x)
ce = criterion_CE(logits, y)

with torch.no_grad():
    t3, t4 = teacher_feats(x)                   # teacher frozen
mse = 0.5 * F.mse_loss(f3, t3) + 0.5 * F.mse_loss(f4, t4)

loss = α * ce + (1 - α) * mse
scaler.scale(loss).backward()                   # <-- NaNs here!
```
Error (excerpt):
```
RuntimeError: Function 'ConvolutionBackward0' returned nan values in its 1th output
```
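For completeness, that message is what autograd’s anomaly mode raises when a backward node produces NaNs, i.e. the backward runs under something like:

```python
# raises a RuntimeError as soon as any backward node produces NaNs,
# which is where the "ConvolutionBackward0 ... nan values" message comes from
torch.autograd.set_detect_anomaly(True)
```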
Observation
- The first term (`ce`) is computed from the normal forward pass that ends with the classifier head (and uses BN running stats “V1”).
- The second term (`mse`) relies on intermediate activations from that same pass.
- When I reuse those activations in the joint loss, the backward runs through BN layers whose buffers have already been updated again by the CE pass (stats ≠ saved stats) → NaNs / “saved tensors…” errors. A minimal check below illustrates the buffer drift.
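To show what I mean by the running buffers moving between passes, here is a minimal standalone check with a stock BatchNorm layer in training mode (nothing distillation-specific):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)                          # training mode by default
x = torch.randn(8, 4)

bn(x)                                           # forward #1 updates running_mean / running_var
m1 = bn.running_mean.clone()
bn(x)                                           # forward #2 on the same batch updates them again
print(torch.allclose(m1, bn.running_mean))      # False: the buffers drifted between passes
```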
2. Workaround that does train stably
I split the optimisation into two forward passes / two backward calls, so that each backward sees its own, consistent BN stats.
```python
manual_step = True                      # tells the main loop not to call step() for me
optimizer.zero_grad()

# ---------- pass 1 : CE ----------
logits = student(x, y)                  # logits WITH margin etc.
ce_loss = criterion_CE(logits, y)
scaler.scale(α * ce_loss).backward(retain_graph=True)

# ---------- pass 2 : feature MSE ----------
with torch.no_grad():
    t_feats = teacher.extract_intermediate_features(x, True)      # teacher stays frozen
s_feats = student.module.extract_intermediate_features(x, False)  # needs grad -> outside no_grad

mse = 0.5 * F.mse_loss(s_feats[3], t_feats[3])
mse += 0.5 * F.mse_loss(s_feats[4], t_feats[4])
scaler.scale((1 - α) * mse).backward()

# step
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
```
No more NaNs – training converges.
I applied the same idea to a Hinton KD (CE + KL) that needs logits with and without AM-Softmax margin; two passes work fine there as well.
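(For reference, the KD term there is the standard temperature-scaled KL between the no-margin student logits and the teacher logits; a sketch with placeholder names and temperature `T`:)

```python
T = 4.0                                             # placeholder temperature
kd = F.kl_div(
    F.log_softmax(student_logits_no_margin / T, dim=1),
    F.softmax(teacher_logits / T, dim=1),
    reduction="batchmean",
) * (T * T)                                         # T² keeps gradients on a comparable scale
loss = α * ce_loss + (1 - α) * kd                   # ce_loss uses the WITH-margin logits
```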
3. Question(s)
- Is the “two-pass / two-backward” pattern the right way to keep BatchNorm happy in this scenario, or is there a better recipe to compute a combined loss while still avoiding the BN-buffer mismatch?
- Are there other tricks (e.g., hooks that cache intermediate features, gradient checkpointing, gradient-scaling gotchas) that would let me simplify this? See the sketch after this list for the hook-based variant I have in mind.
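To make the hook question concrete, this is the single-forward variant I would like to get working instead of the two passes (a sketch, assuming the student backbone exposes torchvision-style `layer3` / `layer4` submodules):

```python
feats = {}

def save_to(name):
    def hook(module, inputs, output):
        feats[name] = output                        # cache the activation from the single forward
    return hook

h3 = student.module.layer3.register_forward_hook(save_to("f3"))
h4 = student.module.layer4.register_forward_hook(save_to("f4"))

logits = student(x, y)                              # one forward gives logits AND the cached features
ce = criterion_CE(logits, y)
with torch.no_grad():
    t3, t4 = teacher_feats(x)
mse = 0.5 * F.mse_loss(feats["f3"], t3) + 0.5 * F.mse_loss(feats["f4"], t4)

scaler.scale(α * ce + (1 - α) * mse).backward()     # single backward over one graph
h3.remove(); h4.remove()
```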
Any insight appreciated – I just want to know whether my fix is the standard way, or whether the community has a cleaner pattern?
Thanks!