How to modify a pretrained model

I got it working! This works:

import torch


def replace_bn(module, name):
    '''
    Recursively replace every torch.nn.BatchNorm2d inside nn.Module `module`
    with a BatchNorm2d that has track_running_stats=False.

    Call with module = net (the top-level model) to start.
    '''
    # go through all attributes of module (e.g. the network or a layer) and replace any batch norms found
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.BatchNorm2d:
            print('replaced: ', name, attr_str)
            # note: the new layer gets freshly initialized weight/bias; pretrained affine params are not copied
            new_bn = torch.nn.BatchNorm2d(target_attr.num_features, target_attr.eps, target_attr.momentum, target_attr.affine,
                                          track_running_stats=False)
            setattr(module, attr_str, new_bn)

    # iterate through immediate child modules. Note, the recursion is done by our code, no need to use named_modules()
    for child_name, immediate_child_module in module.named_children():
        replace_bn(immediate_child_module, child_name)

# model is your (pretrained) network
replace_bn(model, 'model')
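One thing to watch out for with the code above: it builds a brand-new BatchNorm2d, so the pretrained affine `weight` and `bias` are reset to ones and zeros instead of being carried over. If you want to keep them, here is a minimal sketch of a helper that copies them across. The name `bn_without_running_stats` is just something I made up for illustration, and the "pretrained" layer below is a stand-in with hand-set parameters.

```python
import torch
import torch.nn as nn


def bn_without_running_stats(old_bn: nn.BatchNorm2d) -> nn.BatchNorm2d:
    """Hypothetical helper: build a BatchNorm2d with track_running_stats=False
    that keeps the (possibly pretrained) affine weight and bias of old_bn."""
    new_bn = nn.BatchNorm2d(
        old_bn.num_features,
        eps=old_bn.eps,
        momentum=old_bn.momentum,
        affine=old_bn.affine,
        track_running_stats=False,
    )
    if old_bn.affine:
        # Copy learned scale/shift; no_grad so autograd does not record the in-place copy
        with torch.no_grad():
            new_bn.weight.copy_(old_bn.weight)
            new_bn.bias.copy_(old_bn.bias)
    return new_bn


# Stand-in for a pretrained layer: give it non-default affine parameters
old = nn.BatchNorm2d(4)
with torch.no_grad():
    old.weight.fill_(2.0)
    old.bias.fill_(0.5)

new = bn_without_running_stats(old)
print(new.track_running_stats)              # False
print(torch.equal(new.weight, old.weight))  # True
```

Inside `replace_bn` you would then call `setattr(module, attr_str, bn_without_running_stats(target_attr))` instead of constructing the layer inline.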

The crux is that you need to recurse through the model and keep replacing layers, mainly because some attributes are themselves modules that contain further submodules (e.g. an nn.Sequential block holding its own batch norms). I think better code than the above would add another if statement after the batch norm check that detects whether the attribute needs to be recursed into, and recurses if so. The above works too, but it does two passes per module: the first loop replaces the batch norms that are direct attributes of the current module, and the second loop goes over the immediate children so that no submodule that should be recursed into is missed.
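The single-pass version suggested above could look roughly like this. It is my own sketch, only checked on a toy model, and `replace_bn_single_pass` and its `prefix` argument are hypothetical names. It walks `named_children()` once: a BatchNorm2d child is swapped in place, any other child is recursed into (leaf layers simply have no children). Unlike the original, it uses `isinstance`, so subclasses of BatchNorm2d would also be replaced.

```python
import torch
import torch.nn as nn


def replace_bn_single_pass(module: nn.Module, prefix: str = "model") -> int:
    """Replace every nn.BatchNorm2d under module with one that has
    track_running_stats=False, recursing into all other children.
    Returns the number of layers replaced."""
    replaced = 0
    # list(...) so reassigning a child does not interfere with the iteration
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            new_bn = nn.BatchNorm2d(
                child.num_features,
                eps=child.eps,
                momentum=child.momentum,
                affine=child.affine,
                track_running_stats=False,
            )
            setattr(module, name, new_bn)  # registers new_bn under the same name
            print("replaced:", f"{prefix}.{name}")
            replaced += 1
        else:
            # Not a batch norm: recurse into it
            replaced += replace_bn_single_pass(child, f"{prefix}.{name}")
    return replaced


# Toy stand-in for a pretrained network, with one nested block
model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.BatchNorm2d(8),
    nn.Sequential(nn.Conv2d(8, 8, 3), nn.BatchNorm2d(8), nn.ReLU()),
)
n = replace_bn_single_pass(model)  # prints "replaced: model.1" and "replaced: model.2.1"
print(n)                           # 2
```

Because the check and the recursion happen in the same loop, each module is visited exactly once, and nested containers like `nn.Sequential` are handled without a separate pass.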

SO: https://stackoverflow.com/questions/58297197/how-to-change-activation-layer-in-pytorch-pretrained-module/64161690#64161690

credits: Replacing convs modules with custom convs, then NotImplementedError