Replacing Conv2d modules with custom convs, then NotImplementedError

Hi,

I want to replace the Conv2d modules in an existing complex state-of-the-art neural network with pretrained weights with my own Conv2d functionality, which does something different. For this I wrote a custom class Conv2d_custom(nn.modules.conv._ConvNd). I then wrote the following recursive replacement procedure (it needs to be recursive since the network contains many nested submodules):

def replace_conv(target_conv):
  return Conv2d_custom(target_conv.in_channels, target_conv.out_channels, target_conv.kernel_size,
                       target_conv.stride, target_conv.padding, target_conv.dilation, target_conv.groups,
                       target_conv.bias is not None)

def replace_convs_net(target_network):
  if type(target_network) == nn.Conv2d:

    return replace_conv(target_network)
 
  elif type(target_network) == nn.ModuleList:

    new_modulelist = nn.ModuleList()
    for submodule in target_network.modules():
      if submodule == target_network:
        new_modulelist.append(submodule)
      else:
        new_modulelist.append(replace_convs_net(submodule))
    return new_modulelist

  elif nn.Module in type(target_network).__bases__:

    for attr_str in dir(target_network):
      target_attr = getattr(target_network, attr_str)
      if nn.Module in type(target_attr).__bases__:
        replaced = replace_convs_net(target_attr)
        setattr(target_network, attr_str, replaced)

    return target_network

  else:

    return target_network

However, after I call new_network = replace_convs_net(existing_network), I get the following error during training of new_network:

  [...]
    f, class_f = self.feats(x)
  File "/home/<user>/.conda/envs/pytorch-3.5/lib/python3.5/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  [...]
    out = block(out)
  File "/home/<user>/.conda/envs/pytorch-3.5/lib/python3.5/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/<user>/.conda/envs/pytorch-3.5/lib/python3.5/site-packages/torch/nn/modules/module.py", line 71, in forward
    raise NotImplementedError
NotImplementedError

It works fine without the replacement.

For the sake of simplicity, you can assume that Conv2d_custom is an exact clone of Conv2d, code-wise. Yes, I did exactly that to rule out that the error comes from my custom class, and it does not come from there. The error must come from my replacement function. Does anyone see an issue with it? Did I destroy some internal state? I had hoped PyTorch would allow me this flexibility, since it is advertised with "dynamic graphs" and so on.

Note that I cannot just replace the Conv2d with Conv2d_custom in the initial network definition since I first need to load that network as-is with existing weights. The new weights of my Conv2d_custom will then be a non-trivial transformation of the original Conv2d weights. It really needs to be done afterwards on top of the already existing structure.
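For completeness: replace_conv above builds the replacement with fresh weights; in my real code I then copy the transformed pretrained weights over, roughly like this (transform_weight is a placeholder for the actual non-trivial transformation):

def replace_conv_with_weights(target_conv):
  new_conv = replace_conv(target_conv)
  # transform_weight is a placeholder for my non-trivial weight transformation
  new_conv.weight.data.copy_(transform_weight(target_conv.weight.data))
  if target_conv.bias is not None:
    new_conv.bias.data.copy_(target_conv.bias.data)
  return new_conv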

PyTorch 0.3.1

For those who seek a solution and not snobbish mentoring:

def replace_bn(m, name):
    for attr_str in dir(m):
        target_attr = getattr(m, attr_str)
        if type(target_attr) == torch.nn.BatchNorm2d:
            print('replaced: ', name, attr_str)
            setattr(m, attr_str, SynchronizedBatchNorm2d(target_attr.num_features, target_attr.eps, target_attr.momentum, target_attr.affine))
    for n, ch in m.named_children():
        replace_bn(ch, n)
        
replace_bn(net, "net")
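The same pattern should carry over to the original Conv2d question; a sketch, assuming Conv2d_custom takes the same constructor arguments as nn.Conv2d:

def replace_conv(m, name):
    for attr_str in dir(m):
        target_attr = getattr(m, attr_str)
        if type(target_attr) == torch.nn.Conv2d:
            print('replaced: ', name, attr_str)
            setattr(m, attr_str, Conv2d_custom(target_attr.in_channels, target_attr.out_channels,
                                               target_attr.kernel_size, target_attr.stride,
                                               target_attr.padding, target_attr.dilation,
                                               target_attr.groups, target_attr.bias is not None))
    for n, ch in m.named_children():
        replace_conv(ch, n)

replace_conv(net, "net")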

Hi,

I am trying to do something similar. I want to use a CNN in PyTorch, but replace Conv1d with convZ (a conv generated by Zahra), in which + and * are replaced by - and /.
I would appreciate it if you could guide me on how you replaced Conv2d with custom convs.

Best Regards

Would using m.named_modules() work? I think yes, but since replace_bn(ch, n) is already recursing, it's not necessary to recurse again…
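For reference, here is a flat sketch using named_modules(); it assumes a newer PyTorch where Module.get_submodule exists, and it materializes the list first because the tree is mutated while walking it:

import torch.nn as nn

def replace_bn_flat(model):
    # named_modules() already walks the whole tree, so no explicit recursion is needed
    for full_name, m in list(model.named_modules()):
        if isinstance(m, nn.BatchNorm2d):
            parent_name, _, attr_str = full_name.rpartition('.')
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, attr_str, nn.BatchNorm2d(m.num_features, m.eps, m.momentum, m.affine))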

It seems to have worked in my use case. Thanks for the code and for not being snobbish! :laughing:

Do the same as the code above, but construct your custom conv instead. Did you try that? Paste your attempts and current errors.

Thanks @IlyaOvodov, that works! I made a version of your code with some comments and a natural-language explanation of it.

def replace_bn(module, name):
    '''
    Recursively replace the BatchNorm2d layers in nn.Module module with the desired batch norm.

    Set module = net to start.
    '''
    # go through all attributes of this module (e.g. a network or a layer) and replace batch norms where present
    for attr_str in dir(module):
        target_attr = getattr(m, attr_str)
        if type(target_attr) == torch.nn.BatchNorm2d:
            print('replaced: ', name, attr_str)
            new_bn = torch.nn.BatchNorm2d(target_attr.num_features, target_attr.eps, target_attr.momentum, target_attr.affine,
                                          track_running_stats=False)
            setattr(module, attr_str, new_bn)

    # iterate through immediate child modules. Note: the recursion is done by our code, so there is no need to use named_modules()
    for name, immediate_child_module in module.named_children():
        replace_bn(immediate_child_module, name)

replace_bn(model, 'model')

The crux is that you need to keep changing the layers recursively (mainly because sometimes you will encounter attributes that are themselves modules containing further modules). I think better code than the above would add another if statement (after the batch-norm check) that detects whether you have to recurse, and recurses if so; see the sketch after this paragraph. The above works too, but it first changes the batch norms at the current level (the first loop) and then, with another loop, makes sure no other object that should be recursed into is missed (and then recurses).
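A sketch of that single-loop variant, using named_children() so that only registered submodules are touched (untested, same BatchNorm2d arguments as above):

def replace_bn_single_loop(module):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.BatchNorm2d):
            # simple module: swap it out in place
            new_bn = torch.nn.BatchNorm2d(child.num_features, child.eps, child.momentum,
                                          child.affine, track_running_stats=False)
            setattr(module, name, new_bn)
        else:
            # compound module: recurse into it right away
            replace_bn_single_loop(child)

replace_bn_single_loop(model)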


Thanks for this answer! You just have a typo where m should be module.

Darn it, I probably fixed it in my real code, of course. I will fix my reply if the forum allows it (it didn't; hope people see our conversation).


Here is a general function for replacing any layer:

def replace_layers(model, old, new):
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            ## compound module, go inside it
            replace_layers(module, old, new)
            
        if isinstance(module, old):
            ## simple module
            setattr(model, n, new)

replace_layers(model, nn.ReLU, nn.ReLU6())

Note: it uses isinstance() to check layers, so be cautious when replacing nn.Sequential. Since it is a compound module, I would recommend handling it manually.
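For example, to handle an nn.Sequential manually, you can assign by index (nn.Sequential supports item assignment); a small sketch with made-up layers:

import torch.nn as nn

seq = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3), nn.ReLU())
for i, layer in enumerate(seq):
    if isinstance(layer, nn.ReLU):
        seq[i] = nn.ReLU6()  # each slot gets its own fresh ReLU6 instance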

I struggled with this for a few days, so I did some digging and wrote a Kaggle notebook explaining how different types of layers / modules are accessed in PyTorch.
