Freezing BatchNorm layers leads to NaN

Hi, everyone

I want to freeze BatchNorm while fine-tuning my ResNet (that is, use the global/running mean and std, and freeze the weight and bias in BN), but the loss is very large and eventually becomes NaN:

iter =  0 of 20000 completed, loss =  [ 15156.56640625]
iter =  1 of 20000 completed, loss =  [ nan]
iter =  2 of 20000 completed, loss =  [ nan]

The code I use to freeze BatchNorm is:

import torch.nn as nn

def freeze_bn(model):
    # Walk the children recursively; put every BatchNorm2d into eval mode
    # so it uses the running mean/var instead of the batch statistics.
    for name, module in model.named_children():
        if isinstance(module, nn.BatchNorm2d):
            module.eval()
            print('freeze: ' + name)
        else:
            freeze_bn(module)

model.train()
freeze_bn(model)

If I delete freeze_bn(model), the loss converges:

iter =  0 of 20000 completed, loss =  [ 27.71678734]
iter =  1 of 20000 completed, loss =  [ 15.12455177]
iter =  2 of 20000 completed, loss =  [ 16.49391365]
iter =  3 of 20000 completed, loss =  [ 16.47186661]
iter =  4 of 20000 completed, loss =  [ 6.9540534]
iter =  5 of 20000 completed, loss =  [ 7.13955498]
iter =  6 of 20000 completed, loss =  [ 4.7441926]
iter =  7 of 20000 completed, loss =  [ 15.24151039]
iter =  8 of 20000 completed, loss =  [ 12.98035049]
iter =  9 of 20000 completed, loss =  [ 3.7848444]
iter =  10 of 20000 completed, loss =  [ 4.14818573]
iter =  11 of 20000 completed, loss =  [ 4.04237747]
iter =  12 of 20000 completed, loss =  [ 4.52667046]
iter =  13 of 20000 completed, loss =  [ 4.85921001]
iter =  14 of 20000 completed, loss =  [ 3.59978628]

Why do the global mean and std make the loss NaN?
Hoping for help. Thank you!
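
For clarity, here is a minimal sketch of the behaviour I am after (not my exact training script): put every BatchNorm2d into eval mode so it uses the running statistics, and also stop gradient updates to its affine weight and bias. The resnet50 call is only an example model.

import torch.nn as nn
import torchvision

# Sketch only: freeze all BatchNorm2d layers of a torchvision ResNet.
model = torchvision.models.resnet50(pretrained=True)
model.train()

for module in model.modules():          # modules() walks the tree recursively
    if isinstance(module, nn.BatchNorm2d):
        module.eval()                   # use running mean/var, stop updating them
        for p in module.parameters():   # freeze the affine weight and bias
            p.requires_grad = False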

def set_bn_eval(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()

model.apply(set_bn_eval)

You should use apply instead of searching the children yourself.
named_children() doesn't recursively search submodules.
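
To illustrate the difference (a toy sketch, not code from this thread): apply() visits every module in the tree, so a BN layer nested inside a submodule is reached without any manual recursion.

import torch.nn as nn

def set_bn_eval(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

# A toy model with a BN layer buried one level deep.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.Sequential(nn.BatchNorm2d(8), nn.ReLU()),
)
model.train()
model.apply(set_bn_eval)   # apply() recurses into the nested Sequential

print([m.training for m in model.modules() if isinstance(m, nn.BatchNorm2d)])
# -> [False]: the nested BN is now in eval mode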


What is your learning rate? Your loss is 15156; don't you think it is too big?

But
else: freeze_bn(module)
does freeze the layers recursively.

I encountered a similar problem after freezing BN. My loss doesn't rise to a very high magnitude, but it abruptly becomes NaN at some epoch. Anyway, the problem seems to be gone after I turned down the learning rate, though I still don't know its cause.
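
In case it helps others hitting the same NaN: two common safeguards when fine-tuning with frozen BN are a smaller learning rate and gradient clipping (the latter is not mentioned in this thread, just general practice). A self-contained sketch with a toy model and illustrative values:

import torch
import torch.nn as nn

# Toy model and data just to make the snippet runnable; values are illustrative.
model = nn.Linear(10, 2)
inputs = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))
criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)  # smaller LR

loss = criterion(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)  # cap gradient norm
optimizer.step()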