Gradients not computed after replacing BatchNorm with GroupNorm

I’m fine-tuning ResNet18 and wanted to try replacing the BatchNorm layers with GroupNorm to see whether I can train with smaller batches, but somehow none of the weights in the network update once I replace the BatchNorms with GroupNorms.

Here is my code to replace the BatchNorm layers:

import torch


def convert_bn_model_to_gn(module, num_groups=16):
    # Recursively swap every BatchNorm layer in `module` for a GroupNorm layer.
    mod = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        num_groups = 1  # note: this overrides the argument, so every GroupNorm ends up with a single group
        mod = torch.nn.GroupNorm(num_groups, module.num_features,
                                 eps=module.eps, affine=module.affine)
        # mod = torch.nn.modules.linear.Identity()
        if module.affine:
            # carry over the learned scale and shift from the BatchNorm layer
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
    # recurse into children and re-attach the converted sub-modules
    for name, child in module.named_children():
        mod.add_module(name, convert_bn_model_to_gn(child, num_groups=num_groups))
    del module
    return mod
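
For context, this is roughly how I wire it into my setup (a simplified sketch; the actual fine-tuning details, data loading, and hyperparameters are omitted):

import torch
import torchvision

# load a ResNet18 and swap its BatchNorms for GroupNorms before training
model = torchvision.models.resnet18(pretrained=True)  # or weights=... on newer torchvision
model = convert_bn_model_to_gn(model, num_groups=16)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()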

Inspecting the parameters (even the parameters of the dense layer at the top, right before the loss), the gradients are all zero after calling loss.backward(), and the weights don’t change. When I keep the BatchNorms, it works fine. It even works when I use the same routine to replace the BatchNorms with fresh BatchNorms.
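
This is roughly how I’m checking the gradients (a sketch; `images`, `labels`, and `criterion` come from my training loop):

out = model(images)
loss = criterion(out, labels)
loss.backward()

# print the gradient norm of every parameter; with GroupNorm they all come out as 0.0
for name, p in model.named_parameters():
    grad_norm = p.grad.norm().item() if p.grad is not None else None
    print(name, p.requires_grad, grad_norm)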

Somehow, replacing the BNs with GNs seems to disconnect the graph or something, and gradients don’t get computed. Any ideas what I’m doing wrong? Or ideas on how to debug this?
