Hi,
I want to replace the Conv2d modules in an existing, complex, state-of-the-art neural network with pretrained weights by my own Conv2d implementation, which does something different. For this, I wrote a custom class `class Conv2d_custom(nn.modules.conv._ConvNd)`. Then I wrote the following recursive replacement procedure (it needs to be recursive, since the network has many nested submodules):
```python
def replace_conv(target_conv):
    return Conv2d_custom(target_conv.in_channels, target_conv.out_channels, target_conv.kernel_size,
                         target_conv.stride, target_conv.padding, target_conv.dilation, target_conv.groups,
                         target_conv.bias is not None)

def replace_convs_net(target_network):
    if type(target_network) == nn.Conv2d:
        return replace_conv(target_network)
    elif type(target_network) == nn.ModuleList:
        new_modulelist = nn.ModuleList()
        for submodule in target_network.modules():
            if submodule == target_network:
                new_modulelist.append(submodule)
            else:
                new_modulelist.append(replace_convs_net(submodule))
        return new_modulelist
    elif nn.Module in type(target_network).__bases__:
        for attr_str in dir(target_network):
            target_attr = getattr(target_network, attr_str)
            if nn.Module in type(target_attr).__bases__:
                replaced = replace_convs_net(target_attr)
                setattr(target_network, attr_str, replaced)
        return target_network
    else:
        return target_network
```
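For reference when reading the `ModuleList` branch and the `__bases__` checks above, here is a small, self-contained check of how these calls behave on a toy container (the layer shapes are arbitrary and not taken from the real network). `modules()` walks the whole subtree recursively and yields the container itself first, whereas `children()` yields only the direct submodules; `__bases__` lists only the immediate parent classes, unlike `isinstance`. This sketch was written against a recent PyTorch release, not verified on 0.3.1.

```python
import torch.nn as nn

# Toy container (illustrative only): a ModuleList holding a conv and a nested ModuleList.
blocks = nn.ModuleList([
    nn.Conv2d(3, 8, 3),
    nn.ModuleList([nn.Conv2d(8, 8, 3), nn.ReLU()]),
])

# modules() is recursive and starts with the container itself:
# blocks, conv, inner ModuleList, conv, ReLU.
all_modules = list(blocks.modules())

# children() yields only the direct submodules: conv, inner ModuleList.
direct_children = list(blocks.children())

print(len(all_modules), len(direct_children), all_modules[0] is blocks)

# __bases__ holds only the immediate parents. nn.Conv2d derives from _ConvNd,
# not directly from nn.Module, so this membership test is False for a conv,
# while isinstance() follows the full inheritance chain.
conv = nn.Conv2d(3, 8, 3)
print(nn.Module in type(conv).__bases__, isinstance(conv, nn.Module))
```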
However, after I call `new_network = replace_convs_net(existing_network)`, I get the following error while training `new_network`:
```
[...]
    f, class_f = self.feats(x)
  File "/home/<user>/.conda/envs/pytorch-3.5/lib/python3.5/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
[...]
    out = block(out)
  File "/home/<user>/.conda/envs/pytorch-3.5/lib/python3.5/site-packages/torch/nn/modules/module.py", line 357, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/<user>/.conda/envs/pytorch-3.5/lib/python3.5/site-packages/torch/nn/modules/module.py", line 71, in forward
    raise NotImplementedError
NotImplementedError
```
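The last frames of the traceback end in the base `nn.Module` `forward`. As a hedged illustration (checked against a recent PyTorch release, not 0.3.1), a pure container such as `nn.ModuleList` does not define its own `forward`, so calling it like a layer raises `NotImplementedError`, while iterating over it and calling each contained layer works:

```python
import torch
import torch.nn as nn

# nn.ModuleList is only a container: it has no forward() of its own, so calling
# it falls through to the base nn.Module forward, which raises NotImplementedError.
container = nn.ModuleList([nn.Linear(4, 4)])
x = torch.zeros(1, 4)

try:
    container(x)
    raised = False
except NotImplementedError:
    raised = True

print(raised)

# Calling the contained layers one by one works as expected.
out = x
for layer in container:
    out = layer(out)
print(out.shape)
```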
It works fine without the replacement.
For the sake of simplicity, you can assume that `Conv2d_custom` is a code-wise exact clone of `Conv2d`. I actually tested exactly that to rule out that the error comes from the custom class, and it does not. The error must come from my replacement function. Does anyone see an issue with it? Did I destroy some internal state? I had hoped that PyTorch would allow this kind of flexibility, since it is advertised with "dynamic graphs" and so on.
Note that I cannot simply replace Conv2d with Conv2d_custom in the original network definition, since I first need to load the network as-is with its existing weights. The weights of my Conv2d_custom will then be a non-trivial transformation of the original Conv2d weights, so the replacement really has to happen afterwards, on top of the already existing structure.
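Since the weights have to be carried over from the already-loaded network, one common pattern is to walk the module tree with `named_children()` and swap layers in place with `setattr`, so that every container (`Sequential`, `ModuleList`, custom modules) keeps its structure. This is only a sketch under stated assumptions: `Conv2dCustom` and `transform_weight` are hypothetical placeholders (the custom class is assumed to subclass `nn.Conv2d`, and the transformation is the identity), it targets a recent PyTorch release rather than 0.3.1, and it copies only the main constructor arguments (not, for example, `padding_mode`).

```python
import torch
import torch.nn as nn


class Conv2dCustom(nn.Conv2d):
    """Hypothetical stand-in for Conv2d_custom: assumed to behave exactly like nn.Conv2d."""


def transform_weight(weight):
    """Hypothetical placeholder for the real, non-trivial transformation (identity here)."""
    return weight.clone()


def make_custom(conv):
    # Rebuild the layer from the original conv's main constructor arguments.
    new = Conv2dCustom(conv.in_channels, conv.out_channels, conv.kernel_size,
                       stride=conv.stride, padding=conv.padding,
                       dilation=conv.dilation, groups=conv.groups,
                       bias=conv.bias is not None)
    # Carry over the (transformed) pretrained parameters without recording gradients.
    with torch.no_grad():
        new.weight.copy_(transform_weight(conv.weight))
        if conv.bias is not None:
            new.bias.copy_(conv.bias)
    return new


def replace_convs_inplace(module):
    # named_children() yields only direct submodules; recursing into each child
    # reaches nested layers while every container keeps its original structure.
    for name, child in module.named_children():
        if type(child) is nn.Conv2d:
            # Module.__setattr__ re-registers the new layer under the same name,
            # which also covers the "0", "1", ... keys of Sequential/ModuleList.
            setattr(module, name, make_custom(child))
        else:
            replace_convs_inplace(child)
    return module


# Small usage example on a toy network (not the network from the question).
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU()),
)
x = torch.randn(2, 3, 16, 16)
expected = net(x)

replace_convs_inplace(net)
out = net(x)

n_custom = sum(isinstance(m, Conv2dCustom) for m in net.modules())
print(n_custom, torch.allclose(out, expected))
```

Because the swap happens in place, it can run after the pretrained state dict has been loaded. The new layers are created on the CPU, so a network that already lives on the GPU would need to be moved again (e.g. with `net.cuda()`) after the replacement.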
PyTorch version: 0.3.1