I want to replace the Conv2d modules in an existing, complex, state-of-the-art neural network with pretrained weights by my own Conv2d implementation, which does something different. For this, I wrote a custom class `class Conv2d_custom(nn.modules.conv._ConvNd)`. Then, I wrote the following recursive replacement procedure (it needs to be recursive since the network has many nested submodules):
```python
import torch.nn as nn

# Conv2d_custom is the custom class mentioned above (definition not shown).

def replace_conv(target_conv):
    return Conv2d_custom(target_conv.in_channels, target_conv.out_channels,
                         target_conv.kernel_size, target_conv.stride,
                         target_conv.padding, target_conv.dilation,
                         target_conv.groups, target_conv.bias is not None)

def replace_convs_net(target_network):
    if type(target_network) == nn.Conv2d:
        return replace_conv(target_network)
    elif type(target_network) == nn.ModuleList:
        new_modulelist = nn.ModuleList()
        for submodule in target_network.modules():
            if submodule == target_network:
                new_modulelist.append(submodule)
            else:
                new_modulelist.append(replace_convs_net(submodule))
        return new_modulelist
    elif nn.Module in type(target_network).__bases__:
        for attr_str in dir(target_network):
            target_attr = getattr(target_network, attr_str)
            if nn.Module in type(target_attr).__bases__:
                replaced = replace_convs_net(target_attr)
                setattr(target_network, attr_str, replaced)
        return target_network
    else:
        return target_network
```
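A likely issue in the `nn.ModuleList` branch: `Module.modules()` is recursive and yields the module itself first, followed by every descendant at every depth. The new list therefore receives the original `ModuleList` as its first element, followed by a flattened copy of all nested modules. A second, smaller pitfall is the `__bases__` test, which only inspects direct parent classes. The sketch below demonstrates both on a toy container; `blocks` and `MyConv` are hypothetical names used only for illustration.

```python
import torch.nn as nn

# A small container resembling a nested block structure.
blocks = nn.ModuleList([
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()),
    nn.Conv2d(8, 8, 3),
])

# modules() is recursive: it yields the container itself first,
# then every descendant, including modules nested inside the Sequential.
all_names = [type(m).__name__ for m in blocks.modules()]
print(all_names)  # ['ModuleList', 'Sequential', 'Conv2d', 'ReLU', 'Conv2d']

# children() yields only the direct sub-modules, which is what a
# one-level-at-a-time recursion needs.
child_names = [type(m).__name__ for m in blocks.children()]
print(child_names)  # ['Sequential', 'Conv2d']

# The __bases__ test only looks at direct parents: a class derived from a
# subclass of nn.Module fails it, while isinstance() accepts it.
class MyConv(nn.Conv2d):
    pass

print(nn.Module in MyConv.__bases__, isinstance(MyConv(1, 1, 1), nn.Module))  # False True
```

If the network contains classes derived from `nn.Sequential`, `nn.ModuleList` or other built-in modules, the `__bases__` check may silently skip them. `isinstance(x, nn.Module)` avoids that.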
However, after I call `new_network = replace_convs_net(existing_network)`, I get the following error during training of the new network:
```
[...]
    f, class_f = self.feats(x)
  File "/home/<user>/.conda/envs/pytorch-3.5/lib/python3.5/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
[...]
    out = block(out)
  File "/home/<user>/.conda/envs/pytorch-3.5/lib/python3.5/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/<user>/.conda/envs/pytorch-3.5/lib/python3.5/site-packages/torch/nn/modules/module.py", line 71, in forward
    raise NotImplementedError
NotImplementedError
```
It works fine without the replacement.
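The traceback ends in the default `forward` of `nn.Module`, which raises `NotImplementedError`. Container classes such as `nn.ModuleList` do not define their own `forward`, so calling one directly ends up there. The sketch below reproduces that symptom. It assumes that `block` in `out = block(out)` turned out to be an `nn.ModuleList`, which would be consistent with the `ModuleList` branch of `replace_convs_net` appending the original list into the new one.

```python
import torch
import torch.nn as nn

# nn.ModuleList only stores sub-modules; it has no forward() of its own,
# so calling it falls through to nn.Module's default forward, which raises.
blocks = nn.ModuleList([nn.Conv2d(3, 3, 1), nn.ReLU()])
x = torch.randn(1, 3, 4, 4)

raised = False
try:
    blocks(x)
except NotImplementedError:
    raised = True
print(raised)  # True

# A loop such as `for block in ...: out = block(out)` expects every element
# to be a callable layer, not another container.
out = x
for block in blocks:
    out = block(out)
print(out.shape)  # torch.Size([1, 3, 4, 4])
```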
For the sake of simplicity, you can assume that `Conv2d_custom` is an exact clone of `Conv2d` code-wise. Yes, I actually tested that to rule out the custom layer as the source of the error, and the error persists, so it must come from my replacement function. Does anyone see an issue with my replacement function? Did I destroy some internal state? I had hoped that PyTorch would allow this kind of flexibility, since it is advertised with "dynamic graphs" and so on.
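For reference, a minimal sketch of a `Conv2d_custom` that behaves exactly like `Conv2d`. Unlike the class described above, it assumes subclassing `nn.Conv2d` instead of the private `nn.modules.conv._ConvNd`, whose constructor signature has changed between PyTorch versions. Subclassing `nn.Conv2d` keeps the positional constructor order used by `replace_conv` and the `weight`/`bias` parameter names, so a pretrained `state_dict` loads directly. Only zero padding is handled here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv2d_custom(nn.Conv2d):
    """Hypothetical stand-in: behaves exactly like nn.Conv2d (zero padding only)."""

    def forward(self, x):
        # The custom computation would go here; this version reproduces nn.Conv2d.
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

ref = nn.Conv2d(3, 8, kernel_size=3, padding=1)
custom = Conv2d_custom(3, 8, kernel_size=3, padding=1)
custom.load_state_dict(ref.state_dict())  # same parameter names: weight, bias

x = torch.randn(2, 3, 16, 16)
print(torch.allclose(ref(x), custom(x)))  # True
```

Because `Conv2d_custom` is then a subclass, the exact-type check `type(m) == nn.Conv2d` in `replace_convs_net` still distinguishes original layers from replaced ones.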
Note that I cannot simply replace Conv2d with Conv2d_custom in the original network definition, since I first need to load the network as-is with its existing weights. The weights of my Conv2d_custom layers will then be a non-trivial transformation of the original Conv2d weights, so the replacement really has to be done afterwards, on top of the already existing structure.
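One common way to do this kind of post-hoc swap, sketched here under stated assumptions and not as the poster's code: recurse over `named_children()` only, and replace each `nn.Conv2d` in place with `setattr` on its parent. Named sub-modules of `nn.Sequential`, `nn.ModuleList` and `nn.ModuleDict` are registered the same way as plain attributes, so no per-container special case is needed. `Conv2d_custom` is again a hypothetical `nn.Conv2d` subclass, and `transform_weight` is a hypothetical placeholder (identity here) for the non-trivial weight transformation. `padding_mode` and any other extra attributes are not copied.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2d_custom(nn.Conv2d):
    """Hypothetical stand-in: same computation as nn.Conv2d (zero padding only)."""

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


def transform_weight(w):
    # Hypothetical placeholder for the non-trivial weight transformation.
    return w.clone()


def make_custom(conv):
    # Build the replacement with the same hyperparameters, on the same device/dtype.
    new = Conv2d_custom(conv.in_channels, conv.out_channels, conv.kernel_size,
                        conv.stride, conv.padding, conv.dilation, conv.groups,
                        conv.bias is not None)
    new = new.to(device=conv.weight.device, dtype=conv.weight.dtype)
    with torch.no_grad():
        new.weight.copy_(transform_weight(conv.weight))
        if conv.bias is not None:
            new.bias.copy_(conv.bias)
    return new


def replace_convs(module):
    # Only look at direct children; the recursion handles deeper levels.
    for name, child in module.named_children():
        if type(child) is nn.Conv2d:
            setattr(module, name, make_custom(child))  # swap in place under the same name
        else:
            replace_convs(child)
    return module


class Net(nn.Module):
    # Toy network with a nested ModuleList, standing in for the real model.
    def __init__(self):
        super().__init__()
        self.stem = nn.Conv2d(3, 8, 3, padding=1)
        self.blocks = nn.ModuleList([
            nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU()),
            nn.Conv2d(8, 8, 1, bias=False),
        ])

    def forward(self, x):
        out = self.stem(x)
        for block in self.blocks:
            out = block(out)
        return out


net = Net().eval()
x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    before = net(x)
    replace_convs(net)
    after = net(x)

n_plain = sum(type(m) is nn.Conv2d for m in net.modules())
print(n_plain, torch.allclose(before, after))  # 0 True
```

With the identity `transform_weight`, the outputs before and after the swap match, which gives a quick sanity check that no layer was dropped, duplicated or re-nested. Once that passes, the identity can be swapped for the real transformation.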