Replace layers in a model with another module that has extra parameters

Hi, I am trying to replace layers in a defined model with another type of layer, but with some extra parameters. An example is something like this:

import torch
import torch.nn as nn

from torchvision.models import resnet18

class ModA(nn.Module):
  def __init__(self, layer):
    super(ModA, self).__init__()
    self.layer = layer
    self.paraA = nn.Parameter(torch.tensor(1.0))

  def forward(self, input):
    return self.layer(self.paraA * input)  # attribute is paraA (was misspelled paramA)

net = resnet18(weights="IMAGENET1K_V2")
layer_names = []
layers = []
for n, m in net.named_modules():
  if isinstance(m, nn.Conv2d):
    layer_names.append(n)
    layers.append(ModA(m))

for n, m in zip(layer_names, layers):
  setattr(net, n, m)

net.to('cuda:0')
# assumes torch.distributed is already initialized and `gpus` is this process's GPU index
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpus])

It will report some error messages like this.

If I remove the extra parameter self.paraA, it works fine (though it still prints a message like this if I keep the module change). It seems to me the reason is that, although I replaced the modules in the model manually, I have not “inserted” the extra parameter self.paraA into model.parameters(). What is the most appropriate way to replace layers in a predefined model with custom modules? Thanks.
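As a side note on the “inserting” concern: assigning an nn.Parameter (or an nn.Module) as an attribute inside __init__ already registers it, so nothing has to be added to model.parameters() by hand; the wrapper only has to be reachable from the model's module tree and actually called in forward(). A minimal check, using a standalone nn.Conv2d (an arbitrary choice) instead of ResNet-18:

```python
import torch
import torch.nn as nn

class ModA(nn.Module):
    """Wraps an existing layer and scales its input by a learnable scalar."""
    def __init__(self, layer):
        super().__init__()
        self.layer = layer                            # registered as a submodule
        self.paraA = nn.Parameter(torch.tensor(1.0))  # registered as a parameter

    def forward(self, input):
        return self.layer(self.paraA * input)

wrapped = ModA(nn.Conv2d(3, 8, kernel_size=3))
names = [n for n, _ in wrapped.named_parameters()]
print(names)  # ['paraA', 'layer.weight', 'layer.bias']
```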

The reason is that replacing layers with setattr and the full layer name does not work for layers nested inside nn.Sequential. A name such as layer1.0.conv1 is not resolved as a path: setattr registers a new top-level submodule whose name literally contains the dots. So in the original code the nn.Conv2d layers inside nn.Sequential are never actually replaced. Printing model.named_modules() shows an extra module with the same name for each nested nn.Conv2d, and the forward pass still runs the original nn.Conv2d, so those wrappers and their extra parameters are never used.
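The effect can be reproduced on a small model. This is a sketch that assumes a hypothetical TinyNet in place of resnet18 (so it runs without torchvision); like ResNet, it nests an nn.Conv2d inside an nn.Sequential:

```python
import torch
import torch.nn as nn

class ModA(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.paraA = nn.Parameter(torch.tensor(1.0))

    def forward(self, input):
        return self.layer(self.paraA * input)

class TinyNet(nn.Module):
    """Hypothetical toy model: a Conv2d nested in an nn.Sequential."""
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

    def forward(self, x):
        return self.layer1(x)

net = TinyNet()
targets = [(n, m) for n, m in net.named_modules() if isinstance(m, nn.Conv2d)]
for n, m in targets:          # n == 'layer1.0'
    setattr(net, n, ModA(m))  # the pattern from the question

# A second module named 'layer1.0' now hangs directly off `net` ...
module_names = [n for n, _ in net.named_modules()]
print(module_names)  # ['', 'layer1', 'layer1.0', 'layer1.1', 'layer1.0']
# ... while the Sequential still holds the original Conv2d.
print(type(net.layer1[0]).__name__)  # Conv2d

# forward() never calls the wrapper, so its extra parameter gets no gradient.
net(torch.randn(1, 3, 8, 8)).sum().backward()
orphan = getattr(net, 'layer1.0')
print(orphan.paraA.grad)  # None
```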

One workaround, although not elegant, is to replace this part

for n, m in zip(layer_names, layers):
  setattr(net, n, m)

by

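for n, m in zip(layer_names, layers):
  levels = n.split('.')
  if len(levels) > 1:
    # walk down to the parent module, e.g. 'layer1.0.conv1' -> net.layer1[0]
    mod_ = net
    for l_idx in range(len(levels) - 1):
      if levels[l_idx].isdigit():
        # numeric names are indices into nn.Sequential
        mod_ = mod_[int(levels[l_idx])]
      else:
        mod_ = getattr(mod_, levels[l_idx])
    # replace the child on its real parent, not on the top-level model
    setattr(mod_, levels[-1], m)
  else:
    setattr(net, n, m)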
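A more compact variant of the same idea splits each name into parent path and child name and lets nn.Module.get_submodule (available in recent PyTorch versions) walk the path. This is a sketch: replace_conv2d and TinyNet are hypothetical names, and TinyNet again stands in for resnet18 so the example runs without torchvision. It also checks that every paraA now receives a gradient:

```python
import torch
import torch.nn as nn

class ModA(nn.Module):
    """Same wrapper as in the question: scales the input by a learnable scalar."""
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.paraA = nn.Parameter(torch.tensor(1.0))

    def forward(self, input):
        return self.layer(self.paraA * input)

def replace_conv2d(model: nn.Module) -> None:
    """Wrap every nn.Conv2d in `model` with ModA, in place (hypothetical helper)."""
    # Collect first, so the module tree is not modified while it is being iterated.
    targets = [(n, m) for n, m in model.named_modules() if isinstance(m, nn.Conv2d)]
    for name, conv in targets:
        # 'layer1.0.conv1' -> parent 'layer1.0', child 'conv1'; 'conv1' -> parent ''
        parent_name, _, child_name = name.rpartition('.')
        # get_submodule('') returns the model itself; numeric parts such as '0'
        # resolve because nn.Sequential stores its children under string keys.
        parent = model.get_submodule(parent_name)
        setattr(parent, child_name, ModA(conv))

class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet18: one top-level and two nested convs."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.layer1 = nn.Sequential(
            nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 8, 3, padding=1)
        )

    def forward(self, x):
        return self.layer1(self.conv1(x))

net = TinyNet()
replace_conv2d(net)  # do this before .to(device), so the new paraA tensors move too

# Every wrapper is now on the forward path, so each paraA receives a gradient.
net(torch.randn(1, 3, 8, 8)).sum().backward()
extra = [p for n, p in net.named_parameters() if n.endswith('paraA')]
print(len(extra))                              # 3
print(all(p.grad is not None for p in extra))  # True
```

Since DistributedDataParallel collects the model's parameters when it is constructed, the replacement should be done before wrapping the model in DDP. Calling replace_conv2d twice would wrap the convolutions a second time, so it is meant to run once per model.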