I think this will work for you; just change it to your custom layer. Let us know if it works:
import torch

def replace_bn(module, name):
    '''
    Recursively replace every torch.nn.BatchNorm2d in an nn.Module with the desired batch norm.
    Call with module = net to start.
    '''
    # go through all attributes of the module (e.g. a network or layer) and replace batch norms if present
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.BatchNorm2d:
            print('replaced: ', name, attr_str)
            new_bn = torch.nn.BatchNorm2d(target_attr.num_features, target_attr.eps,
                                          target_attr.momentum, target_attr.affine,
                                          track_running_stats=False)
            setattr(module, attr_str, new_bn)
    # iterate through immediate child modules; the recursion is done by our code, so there is no need to use named_modules()
    for name, immediate_child_module in module.named_children():
        replace_bn(immediate_child_module, name)

replace_bn(model, 'model')
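For reference, here is a minimal usage sketch, assuming a torchvision ResNet-18 as the model (any nn.Module works); swap the new_bn line above for your custom layer:

import torchvision

model = torchvision.models.resnet18()
replace_bn(model, 'model')
# both the top-level BN and the ones inside the residual blocks are replaced
print(model.bn1.track_running_stats)            # False
print(model.layer1[0].bn1.track_running_stats)  # False

Note that the freshly constructed BatchNorm2d starts with new affine parameters and no running statistics, so a pretrained model will need some fine-tuning after the swap.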
original post: How to modify a pretrained model