[Knowledge Distillation]: NaNs when combining CE + aux MSE in the same backward (BatchNorm state mismatch)

Hi all,

I’m implementing knowledge distillation from a teacher ResNet-101 into a student ResNet-50 for a speaker-verification task.
The total loss I want is

L = α · CE(logits_student, y)                                                   # classification
  + (1-α) · ½ · [ MSE(f3_student, f3_teacher) + MSE(f4_student, f4_teacher) ]   # layer3 / layer4 feature matching
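For reference, here is the same loss as a small standalone function. This is a sketch. `kd_loss` is a hypothetical name, plain `F.cross_entropy` stands in for my actual `criterion_CE` (which adds an AM-Softmax margin), and it assumes the student and teacher feature maps already have matching shapes.

```python
import torch
import torch.nn.functional as F


def kd_loss(logits, y, s_feats, t_feats, alpha):
    """alpha * CE + (1 - alpha) * 0.5 * (MSE_layer3 + MSE_layer4).

    s_feats / t_feats are (layer3, layer4) tuples whose shapes already match
    (assumption: no projection layer between student and teacher features).
    """
    ce = F.cross_entropy(logits, y)  # stand-in for criterion_CE (no AM-Softmax margin here)
    # Teacher features are fixed targets: detach them so no gradient reaches the teacher.
    mse = 0.5 * F.mse_loss(s_feats[0], t_feats[0].detach()) + 0.5 * F.mse_loss(
        s_feats[1], t_feats[1].detach()
    )
    return alpha * ce + (1.0 - alpha) * mse


torch.manual_seed(0)
logits = torch.randn(4, 10)
y = torch.randint(0, 10, (4,))
f3, f4 = torch.randn(4, 8, 5, 5), torch.randn(4, 16, 3, 3)
# Identical student/teacher features: the MSE part vanishes, leaving 0.5 * CE.
loss = kd_loss(logits, y, (f3, f4), (f3.clone(), f4.clone()), alpha=0.5)
```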

Both networks use SyncBatchNorm and I train in distributed mixed precision (autocast + GradScaler).
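The usual way to obtain SyncBatchNorm layers is `nn.SyncBatchNorm.convert_sync_batchnorm`, applied before the model is wrapped in DDP. Below is a minimal sketch of the conversion. The two-conv `model` is a hypothetical stand-in for the ResNets.

```python
import torch.nn as nn

# Hypothetical toy stand-in for the student/teacher ResNets.
model = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
    nn.Conv2d(8, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
)

# Replace every BatchNorm*d layer with SyncBatchNorm (weights and running stats are copied).
# The conversion only builds modules; cross-rank syncing happens later, in training-mode forward.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

n_sync = sum(isinstance(m, nn.SyncBatchNorm) for m in model.modules())
n_bn2d = sum(type(m) is nn.BatchNorm2d for m in model.modules())
```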


1. What goes wrong with “the naïve way”

import torch
import torch.nn.functional as F

# ↓ one student forward pass that produces both loss terms
logits, f3, f4 = student_forward_return_everything(x)   # logits + layer3/layer4 feature maps

ce = criterion_CE(logits, y)
with torch.no_grad():
    t3, t4 = teacher_feats(x)        # teacher frozen: no graph is recorded

mse = 0.5 * F.mse_loss(f3, t3) + 0.5 * F.mse_loss(f4, t4)

loss = α * ce + (1 - α) * mse
scaler.scale(loss).backward()        # <-- NaNs here!

Error (excerpt):

RuntimeError: Function 'ConvolutionBackward0' returned nan values in its 1th output
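A side note on the error text: this wording comes from autograd's anomaly mode (`torch.autograd.detect_anomaly()` / `set_detect_anomaly(True)`). Anomaly mode raises at the first backward node whose output contains a NaN, so `ConvolutionBackward0` is where the NaN was detected, not necessarily where it originated. Also, with GradScaler an occasional step with non-finite gradients is expected, because the scaler skips that step and lowers the scale. Anomaly mode may therefore raise on iterations the scaler would otherwise have handled. The tiny CPU reproduction below shows where the message comes from. It is a toy example and unrelated to the actual model.

```python
import torch

x = torch.tensor([-1.0, 4.0], requires_grad=True)
message = ""
# Anomaly mode checks the outputs of every backward function for NaNs and
# raises at the first one it finds.
with torch.autograd.detect_anomaly():
    y = torch.sqrt(x)          # sqrt(-1) is already NaN in the forward pass
    try:
        y.sum().backward()     # so the gradient computed by SqrtBackward0 is NaN too
    except RuntimeError as err:
        message = str(err)
print(message)
```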

Observation

  • The first term (ce) comes from the normal forward that ends in the classifier head (BN running stats “V1”).
  • The second term (mse) relies on intermediate activations.
    When I reuse those same activations in the joint loss, the backward runs through BN layers whose buffers have since been updated again by the CE pass (stats ≠ saved stats) → NaNs / “Saved tensors...” errors.
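One thing worth ruling out before blaming the student's BN: `torch.no_grad()` only stops graph recording. It does not switch BatchNorm to inference behaviour. If the teacher is still in `train()` mode, its BN running stats are updated on every call, and with SyncBatchNorm in training mode its batch statistics are also synchronised across ranks. The minimal check below uses a single plain `nn.BatchNorm1d` as a stand-in for one teacher BN layer. SyncBatchNorm is assumed to behave the same way here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)          # stand-in for one BN layer of the teacher
x = torch.randn(16, 4) * 3 + 1  # batch whose statistics differ from the init (mean 0, var 1)

# Train mode + no_grad: no graph is recorded, but the running stats still move.
bn.train()
with torch.no_grad():
    bn(x)
moved_under_no_grad = not torch.allclose(bn.running_mean, torch.zeros(4))
tracked_after_train = int(bn.num_batches_tracked)

# Eval mode: the layer normalises with its running stats and leaves them untouched.
snapshot = bn.running_mean.clone()
bn.eval()
with torch.no_grad():
    bn(x)
unchanged_in_eval = torch.equal(bn.running_mean, snapshot)
print(moved_under_no_grad, tracked_after_train, unchanged_in_eval)
```

If the intent is a fully frozen teacher, that suggests `teacher.eval()` plus `teacher.requires_grad_(False)` rather than `no_grad()` alone.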

2. Workaround that trains stably

I split the optimisation into two forward / two backward calls so that each backward sees its own consistent BN stats.

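import torch
import torch.nn.functional as F

manual_step = True          # tells the main loop not to call step() for me
optimizer.zero_grad()

# ---------- pass 1 : CE ----------
logits = student(x, y)                    # logits WITH margin etc.
ce_loss = criterion_CE(logits, y)
# retain_graph is likely unnecessary here: pass 2 runs a fresh forward and builds its own graph
scaler.scale(α * ce_loss).backward(retain_graph=True)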

# ---------- pass 2 : feature MSE ----------
with torch.no_grad():
    t_feats = teacher.extract_intermediate_features(x, True)

s_feats = student.module.extract_intermediate_features(x, False)
mse  = 0.5*F.mse_loss(s_feats[3], t_feats[3])
mse += 0.5*F.mse_loss(s_feats[4], t_feats[4])

scaler.scale((1-α) * mse).backward()

# unscale, clip, then step (scaler.step skips optimizer.step if any grad is inf/NaN)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()

No more NaNs – training converges.

I applied the same idea to Hinton-style KD (CE + KL), which needs logits both with and without the AM-Softmax margin; two passes work fine there as well.
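To see what the second pass actually changes, here is a toy CPU experiment. It uses plain BatchNorm2d with no DDP, no AMP and no margin, and `ToyStudent` and `grads_and_bn_updates` are hypothetical names. With the same `x`, no dropout and BN in train mode, both passes normalise with identical batch statistics. The accumulated two-pass gradients should therefore match those of a single combined backward.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyStudent(nn.Module):
    """Hypothetical stand-in: conv+BN 'layer3', conv+BN 'layer4', linear head."""

    def __init__(self):
        super().__init__()
        self.layer3 = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
        self.layer4 = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
        self.fc = nn.Linear(8, 5)

    def forward(self, x):
        f3 = self.layer3(x)
        f4 = self.layer4(f3)
        return self.fc(f4.mean(dim=(2, 3))), f3, f4


def grads_and_bn_updates(two_pass, alpha=0.3):
    torch.manual_seed(1)  # same init and data for both recipes
    model = ToyStudent().train()
    x, y = torch.randn(6, 1, 7, 7), torch.randint(0, 5, (6,))
    t3, t4 = torch.randn(6, 8, 7, 7), torch.randn(6, 8, 7, 7)  # fake teacher features

    def mse_term(f3, f4):
        return 0.5 * F.mse_loss(f3, t3) + 0.5 * F.mse_loss(f4, t4)

    if two_pass:
        logits, _, _ = model(x)                          # pass 1: CE only
        (alpha * F.cross_entropy(logits, y)).backward()
        _, f3, f4 = model(x)                             # pass 2: feature MSE only
        ((1 - alpha) * mse_term(f3, f4)).backward()      # grads accumulate in .grad
    else:
        logits, f3, f4 = model(x)                        # one forward, one backward
        loss = alpha * F.cross_entropy(logits, y) + (1 - alpha) * mse_term(f3, f4)
        loss.backward()

    grads = [p.grad.clone() for p in model.parameters()]
    return grads, int(model.layer3[1].num_batches_tracked)


g_two, n_two = grads_and_bn_updates(two_pass=True)
g_one, n_one = grads_and_bn_updates(two_pass=False)
same = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(g_two, g_one))
print(same, n_two, n_one)  # gradients match; BN stats were updated twice vs once
```

In this toy setting, the visible difference between the two recipes is BN bookkeeping. The two-pass version advances each BN layer's running statistics (and `num_batches_tracked`) twice per optimizer step. I have not verified that the same holds with SyncBatchNorm + autocast.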


3. Question(s)

  1. Is the “two-pass / two-backward” pattern the right way to keep BatchNorm happy in this scenario,
    or is there a better recipe for computing a combined loss while still avoiding a BN-buffer mismatch?
  2. Are there other tricks (e.g. hooks that cache intermediate features, gradient checkpointing, gradient-scaling gotchas) that would let me simplify this?
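On the hooks idea from question 2, here is a minimal sketch of what I have in mind. Forward hooks store the layer3/layer4 outputs during the normal forward, so one student forward feeds both loss terms and a single backward covers both. Toy modules replace the ResNets, `ToyNet` and `attach_feature_hooks` are hypothetical names, and the sketch runs on CPU without DDP/AMP. Under DDP, the hooks would presumably be registered on `student.module.layer3` etc.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyNet(nn.Module):
    """Hypothetical stand-in with ResNet-style attribute names layer3 / layer4 / fc."""

    def __init__(self):
        super().__init__()
        self.layer3 = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
        self.layer4 = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
        self.fc = nn.Linear(8, 5)

    def forward(self, x):  # returns logits only, like a normal classifier
        return self.fc(self.layer4(self.layer3(x)).mean(dim=(2, 3)))


def attach_feature_hooks(model, names):
    """Store each named submodule's output in a dict on every forward call."""
    feats, handles = {}, []
    for name in names:
        def hook(module, inputs, output, name=name):  # default arg binds the loop variable
            feats[name] = output  # kept in the graph, so gradients flow through it
        handles.append(getattr(model, name).register_forward_hook(hook))
    return feats, handles


torch.manual_seed(0)
student, teacher = ToyNet().train(), ToyNet().eval()
teacher.requires_grad_(False)
s_feats, s_handles = attach_feature_hooks(student, ["layer3", "layer4"])
t_feats, t_handles = attach_feature_hooks(teacher, ["layer3", "layer4"])

x, y, alpha = torch.randn(6, 1, 7, 7), torch.randint(0, 5, (6,)), 0.5
with torch.no_grad():
    teacher(x)          # fills t_feats
logits = student(x)     # a single student forward fills s_feats
mse = (0.5 * F.mse_loss(s_feats["layer3"], t_feats["layer3"])
       + 0.5 * F.mse_loss(s_feats["layer4"], t_feats["layer4"]))
loss = alpha * F.cross_entropy(logits, y) + (1 - alpha) * mse
loss.backward()         # one backward for both terms

for h in s_handles + t_handles:
    h.remove()          # detach the hooks once they are no longer needed
```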

Any insight is appreciated. I mainly want to know whether my fix is the standard approach or whether the community has a cleaner pattern.

Thanks!