As mentioned in this post, where the same approach was suggested, that approach is hacky and works only for simple modules.
If you want to swap the normalization layers properly, you should instead write a custom nn.Module that derives from the ResNet base class and replaces the normalization layers in its __init__ method.
You can reuse the inherited forward method without changing it.