Torch Batch Norm

How can I copy the weights from one model to another (with the same architecture), but only for the batch norm layers?

You can access the corresponding batch norm layers and copy their state_dicts via:

modelA.batch_norm_layer1.load_state_dict(modelB.batch_norm_layer1.state_dict())
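
For context, here is a self-contained version of that one-liner. The ToyNet module and the batch_norm_layer1 attribute name are made up for illustration; substitute whatever attribute your model actually uses:

import torch.nn as nn

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3)
        self.batch_norm_layer1 = nn.BatchNorm2d(16)

    def forward(self, x):
        return self.batch_norm_layer1(self.conv(x))

modelA = ToyNet()
modelB = ToyNet()

# copy the batch norm parameters (weight, bias) and buffers
# (running_mean, running_var, num_batches_tracked) from modelB to modelA
modelA.batch_norm_layer1.load_state_dict(modelB.batch_norm_layer1.state_dict())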

If you need to copy multiple layers, you can of course also iterate over model.children() and use an if condition to filter for e.g. isinstance(child, nn.BatchNorm2d), as in the sketch below.
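
A minimal sketch of that filtering approach, assuming two flat nn.Sequential models whose children line up by position (note that children() only yields direct submodules; for nested models use named_modules() as in the full example further down):

import torch.nn as nn

modelA = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
modelB = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())

# walk both models pairwise and copy only the batch norm layers
for childA, childB in zip(modelA.children(), modelB.children()):
    if isinstance(childA, nn.BatchNorm2d) and isinstance(childB, nn.BatchNorm2d):
        childA.load_state_dict(childB.state_dict())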

But with the above code, we have to write a separate line for each layer we want to load.
Let's say we switch to a different model; then we have to change the code again.
Is there any way to make the code dynamic?

In this case you can follow the second part of my post, e.g. by iterating over the modules and using isinstance conditions:

import torch.nn as nn
from torchvision import models

modelA = models.resnet50()
modelB = models.resnet50()

# iterate both models in parallel and copy only the batch norm layers
for (nameA, childA), (nameB, childB) in zip(modelA.named_modules(), modelB.named_modules()):
    if isinstance(childA, nn.BatchNorm2d) and isinstance(childB, nn.BatchNorm2d):
        print("Found {} and {}".format(nameA, nameB))
        childA.load_state_dict(childB.state_dict())

# verify that the affine weights now match
for (nameA, childA), (nameB, childB) in zip(modelA.named_modules(), modelB.named_modules()):
    if isinstance(childA, nn.BatchNorm2d) and isinstance(childB, nn.BatchNorm2d):
        print((childA.weight == childB.weight).all())

In the end the right approach depends on your use case and on whether you want to use the module names, a lookup table, or just iterate over the modules assuming they were created in the same order.
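
For the name-based option, one possible sketch is to filter the full state_dict down to the keys belonging to batch norm modules and load it partially. This assumes the batch norm modules share names across both models (true here, since both are resnet50):

import torch.nn as nn
from torchvision import models

modelA = models.resnet50()
modelB = models.resnet50()

# collect the qualified names of all batch norm modules in modelB
bn_names = {name for name, m in modelB.named_modules() if isinstance(m, nn.BatchNorm2d)}

# keep only the state_dict entries that belong to one of those modules
bn_state = {k: v for k, v in modelB.state_dict().items()
            if any(k.startswith(name + ".") for name in bn_names)}

# strict=False lets load_state_dict ignore all the keys we left out
modelA.load_state_dict(bn_state, strict=False)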