Got it working! I thought of your idea but didn't see how to do it well. Perhaps you could share your code so I can see how it should have been done. For now, this is what works for me:
```python
import torch

def replace_bn(module, name):
    '''
    Recursively replace every BatchNorm2d in an nn.Module with the desired batch norm.
    Call with module = net to start.
    '''
    # go through all attributes of the module (e.g. network or layer) and replace
    # any BatchNorm2d found; note the new layer's parameters are freshly
    # initialized, not copied from the old one
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.BatchNorm2d:
            print('replaced: ', name, attr_str)
            new_bn = torch.nn.BatchNorm2d(target_attr.num_features,
                                          target_attr.eps,
                                          target_attr.momentum,
                                          target_attr.affine,
                                          track_running_stats=False)
            setattr(module, attr_str, new_bn)
    # iterate through immediate child modules; the recursion is done by this
    # function itself, so there is no need to use named_modules()
    for name, immediate_child_module in module.named_children():
        replace_bn(immediate_child_module, name)

replace_bn(model, 'model')
```
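In case it helps anyone else, here is a minimal usage sketch; the torchvision import and the resnet18 backbone are just assumptions for illustration, not the model from the original question:

```python
import torch
import torchvision

# any model containing BatchNorm2d layers will do; resnet18 is only an example
model = torchvision.models.resnet18()
replace_bn(model, 'model')

# after the call, every BatchNorm2d should have track_running_stats disabled
assert all(not bn.track_running_stats
           for bn in model.modules()
           if isinstance(bn, torch.nn.BatchNorm2d))
```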
source post: How to modify a pretrained model
Thanks AlbanD!