Take Resnet50 in torchvision as an example, I want to change all the BatchNorm2d to GroupNorm . How can I implement this efficiently

@ptrblck can you please help me here?

i am unable to do this as well, i tried this :

```
import torch.nn as nn
import torchvision.models as models

model = models.resnet18()

# named_modules() yields dotted paths such as "layer1.0.bn1".  Plain
# getattr/setattr on the root model cannot resolve those (that is the
# AttributeError shown below), so walk down to the direct parent module
# and swap the attribute there instead.
for name, module in list(model.named_modules()):
    if isinstance(module, nn.BatchNorm2d):
        # Split "layer1.0.bn1" into the parent path and the final attribute.
        *parents, attr = name.split('.')
        owner = model
        for part in parents:
            # getattr also resolves numeric Sequential keys like "0".
            owner = getattr(owner, part)
        # num_groups=1 normalizes over all channels (LayerNorm-like).
        gn = nn.GroupNorm(1, module.num_features)
        print('Swapping {} with {}'.format(module, gn))
        setattr(owner, attr, gn)
print(model)
```

and it gives this error :

```
Swapping BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) with GroupNorm(1, 64, eps=1e-05, affine=True)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-26-dc2f23e093cc> in <module>
2 if isinstance(module, nn.BatchNorm2d):
3 # Get current bn layer
----> 4 bn = getattr(model, name)
5 # Create new gn layer
6 gn = nn.GroupNorm(1, bn.num_features)
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
592 return modules[name]
593 raise AttributeError("'{}' object has no attribute '{}'".format(
--> 594 type(self).__name__, name))
595
596 def __setattr__(self, name, value):
AttributeError: 'ResNet' object has no attribute 'layer1.0.bn1'
```

As described in this post, where this approach was also posted, I mentioned that this approach is hacky and would work only for simple modules.

If you want to properly swap the normalization layers, you should instead write a custom `nn.Module`

, derive from the resnet as the base class, and change the normalization layers in the `__init__`

method.

You could reuse the `forward`

without changing it.

@ptrblck if you could give an working example with a torchvision model then it could help a lot

I searched on Google, posted on GitHub and Stack Overflow, and got an answer nowhere — so a working example/demo would be of great help. Thank you anyway.

The current `ResNet`

implementation seems to accept different normalization layers as an argument (you might need to update `torchvision`

, if your installed version doesn't have this argument).

Here is a small code snippet:

```
import torch.nn as nn
from torchvision.models import resnet


class MyGroupNorm(nn.Module):
    """Proxy norm layer constructible from a channel count alone.

    ResNet calls ``norm_layer(planes)`` with a single argument, while
    ``nn.GroupNorm`` also needs ``num_groups`` — this wrapper fixes
    ``num_groups=2`` so it can be passed as ``norm_layer``.
    """

    def __init__(self, num_channels):
        super(MyGroupNorm, self).__init__()
        self.norm = nn.GroupNorm(num_groups=2, num_channels=num_channels,
                                 eps=1e-5, affine=True)

    def forward(self, x):
        x = self.norm(x)
        return x


# resnet18-shaped network (BasicBlock, [2, 2, 2, 2]) with GroupNorm
# substituted for every BatchNorm layer.
model = resnet.ResNet(resnet.BasicBlock, [2, 2, 2, 2], num_classes=1000,
                      zero_init_residual=False, groups=1, width_per_group=64,
                      replace_stride_with_dilation=None,
                      norm_layer=MyGroupNorm)
print(model)
```

Note that I needed to define the proxy `MyGroupNorm`

module, since the initialization of the `norm_layer`

will only take one argument (as seen in this line of code), while `nn.GroupNorm`

needs two.

@ptrblck thank you a lot.

how do i load weight of pretrained model?

i tried this :

```
import torch.nn as nn
from torchvision.models import resnet


class MyGroupNorm(nn.Module):
    """Proxy norm layer constructible from a channel count alone.

    ResNet calls ``norm_layer(planes)`` with one argument, while
    ``nn.GroupNorm`` also needs ``num_groups``; this wrapper fixes
    ``num_groups=2`` so it is usable as ``norm_layer``.
    """

    def __init__(self, num_channels):
        super(MyGroupNorm, self).__init__()
        self.norm = nn.GroupNorm(num_groups=2, num_channels=num_channels,
                                 eps=1e-5, affine=True)

    def forward(self, x):
        x = self.norm(x)
        return x


# resnet34-shaped stack of BasicBlocks ([3, 4, 6, 3]) with a 5-class head
# and GroupNorm in place of BatchNorm.
model = resnet.ResNet(resnet.BasicBlock, [3, 4, 6, 3], num_classes=5,
                      zero_init_residual=False, groups=1, width_per_group=64,
                      replace_stride_with_dilation=None,
                      norm_layer=MyGroupNorm)
print(model)
```

and then this to load pretrained weight of resnext50 :

```
# SE-ResNeXt50 checkpoint — NOTE(review): these weights belong to a
# different architecture than a plain ResNet, so a strict load raises
# the RuntimeError quoted below.
weights_path = '../input/pytorch-se-resnext/se_resnext50_32x4d-a260b3a4.pth'
state_dict = torch.load(weights_path)
model.load_state_dict(state_dict)
```

but it gave me this error :

```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-30-2c5332b5515b> in <module>
1 weights_path = '../input/pytorch-se-resnext/se_resnext50_32x4d-a260b3a4.pth'
----> 2 model.load_state_dict(torch.load(weights_path))
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
845 if len(error_msgs) > 0:
846 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 847 self.__class__.__name__, "\n\t".join(error_msgs)))
848 return _IncompatibleKeys(missing_keys, unexpected_keys)
849
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "conv1.weight", "bn1.norm.weight", "bn1.norm.bias", "layer1.0.bn1.norm.weight", "layer1.0.bn1.norm.bias", "layer1.0.bn2.norm.weight", "layer1.0.bn2.norm.bias", "layer1.1.bn1.norm.weight", "layer1.1.bn1.norm.bias", "layer1.1.bn2.norm.weight", "layer1.1.bn2.norm.bias", "layer1.2.bn1.norm.weight", "layer1.2.bn1.norm.bias", "layer1.2.bn2.norm.weight", "layer1.2.bn2.norm.bias", "layer2.0.bn1.norm.weight", "layer2.0.bn1.norm.bias", "layer2.0.bn2.norm.weight", "layer2.0.bn2.norm.bias", "layer2.0.downsample.1.norm.weight", "layer2.0.downsample.1.norm.bias", "layer2.1.bn1.norm.weight", "layer2.1.bn1.norm.bias", "layer2.1.bn2.norm.weight", "layer2.1.bn2.norm.bias", "layer2.2.bn1.norm.weight", "layer2.2.bn1.norm.bias", "layer2.2.bn2.norm.weight", "layer2.2.bn2.norm.bias", "layer2.3.bn1.norm.weight", "layer2.3.bn1.norm.bias", "layer2.3.bn2.norm.weight", "layer2.3.bn2.norm.bias", "layer3.0.bn1.norm.weight", "layer3.0.bn1.norm.bias", "layer3.0.bn2.norm.weight", "layer3.0.bn2.norm.bias", "layer3.0.downsample.1.norm.weight", "layer3.0.downsample.1.norm.bias", "layer3.1.bn1.norm.weight", "layer3.1.bn1.norm.bias", "layer3.1.bn2.norm.weight", "layer3.1.bn2.norm.bias", "layer3.2.bn1.norm.weight", "layer3.2.bn1.norm.bias", "layer3.2.bn2.norm.weight", "layer3.2.bn2.norm.bias", "layer3.3.bn1.norm.weight", "layer3.3.bn1.norm.bias", "layer3.3.bn2.norm.weight", "layer3.3.bn2.norm.bias", "layer3.4.bn1.norm.weight", "layer3.4.bn1.norm.bias", "layer3.4.bn2.norm.weight", "layer3.4.bn2.norm.bias", "layer3.5.bn1.norm.weight", "layer3.5.bn1.norm.bias", "layer3.5.bn2.norm.weight", "layer3.5.bn2.norm.bias", "layer4.0.bn1.norm.weight", "layer4.0.bn1.norm.bias", "layer4.0.bn2.norm.weight", "layer4.0.bn2.norm.bias", "layer4.0.downsample.1.norm.weight", "layer4.0.downsample.1.norm.bias", "layer4.1.bn1.norm.weight", "layer4.1.bn1.norm.bias", "layer4.1.bn2.norm.weight", "layer4.1.bn2.norm.bias", "layer4.2.bn1.norm.weight", "layer4.2.bn1.norm.bias", "layer4.2.bn2.norm.weight", 
"layer4.2.bn2.norm.bias", "fc.weight", "fc.bias".
Unexpected key(s) in state_dict: "layer0.conv1.weight", "layer0.bn1.weight", "layer0.bn1.bias", "layer0.bn1.running_mean", "layer0.bn1.running_var", "last_linear.weight", "last_linear.bias", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.se_module.fc1.weight", "layer1.0.se_module.fc1.bias", "layer1.0.se_module.fc2.weight", "layer1.0.se_module.fc2.bias", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.1.se_module.fc1.weight", "layer1.1.se_module.fc1.bias", "layer1.1.se_module.fc2.weight", "layer1.1.se_module.fc2.bias", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer1.2.se_module.fc1.weight", "layer1.2.se_module.fc1.bias", "layer1.2.se_module.fc2.weight", "layer1.2.se_module.fc2.bias", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.0.se_module.fc1.weight", "layer2.0.se_module.fc1.bias", "layer2.0.se_module.fc2.weight", "layer2.0.se_module.fc2.bias", 
"layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.1.se_module.fc1.weight", "layer2.1.se_module.fc1.bias", "layer2.1.se_module.fc2.weight", "layer2.1.se_module.fc2.bias", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.2.se_module.fc1.weight", "layer2.2.se_module.fc1.bias", "layer2.2.se_module.fc2.weight", "layer2.2.se_module.fc2.bias", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer2.3.se_module.fc1.weight", "layer2.3.se_module.fc1.bias", "layer2.3.se_module.fc2.weight", "layer2.3.se_module.fc2.bias", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.0.se_module.fc1.weight", "layer3.0.se_module.fc1.bias", "layer3.0.se_module.fc2.weight", "layer3.0.se_module.fc2.bias", "layer3.0.bn1.weight", 
"layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.1.se_module.fc1.weight", "layer3.1.se_module.fc1.bias", "layer3.1.se_module.fc2.weight", "layer3.1.se_module.fc2.bias", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.2.se_module.fc1.weight", "layer3.2.se_module.fc1.bias", "layer3.2.se_module.fc2.weight", "layer3.2.se_module.fc2.bias", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.3.se_module.fc1.weight", "layer3.3.se_module.fc1.bias", "layer3.3.se_module.fc2.weight", "layer3.3.se_module.fc2.bias", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.4.se_module.fc1.weight", "layer3.4.se_module.fc1.bias", "layer3.4.se_module.fc2.weight", "layer3.4.se_module.fc2.bias", "layer3.4.bn1.weight", "layer3.4.bn1.bias", 
"layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer3.5.se_module.fc1.weight", "layer3.5.se_module.fc1.bias", "layer3.5.se_module.fc2.weight", "layer3.5.se_module.fc2.bias", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.0.se_module.fc1.weight", "layer4.0.se_module.fc1.bias", "layer4.0.se_module.fc2.weight", "layer4.0.se_module.fc2.bias", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.1.se_module.fc1.weight", "layer4.1.se_module.fc1.bias", "layer4.1.se_module.fc2.weight", "layer4.1.se_module.fc2.bias", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "layer4.2.se_module.fc1.weight", "layer4.2.se_module.fc1.bias", "layer4.2.se_module.fc2.weight", "layer4.2.se_module.fc2.bias", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", 
"layer4.2.bn1.running_var", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var".
size mismatch for layer1.0.conv1.weight: copying a param with shape torch.Size([128, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer1.0.conv2.weight: copying a param with shape torch.Size([128, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer1.1.conv1.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer1.1.conv2.weight: copying a param with shape torch.Size([128, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer1.2.conv1.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer1.2.conv2.weight: copying a param with shape torch.Size([128, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer2.0.conv1.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).
size mismatch for layer2.0.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.0.downsample.0.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 64, 1, 1]).
size mismatch for layer2.1.conv1.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.1.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.2.conv1.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.2.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.3.conv1.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.3.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer3.0.conv1.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3]).
size mismatch for layer3.0.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.0.downsample.0.weight: copying a param with shape torch.Size([1024, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1]).
size mismatch for layer3.1.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.1.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.2.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.2.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.3.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.3.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.4.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.4.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.5.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.5.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer4.0.conv1.weight: copying a param with shape torch.Size([1024, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 3, 3]).
size mismatch for layer4.0.conv2.weight: copying a param with shape torch.Size([1024, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
size mismatch for layer4.0.downsample.0.weight: copying a param with shape torch.Size([2048, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 1, 1]).
size mismatch for layer4.1.conv1.weight: copying a param with shape torch.Size([1024, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
size mismatch for layer4.1.conv2.weight: copying a param with shape torch.Size([1024, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
size mismatch for layer4.2.conv1.weight: copying a param with shape torch.Size([1024, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
size mismatch for layer4.2.conv2.weight: copying a param with shape torch.Size([1024, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
```

You are trying to load the ResNext `state_dict`

into a ResNet, which won't work out of the box.

To load the pretrained `resnet18`

parameters, you could use:

```
# Load the pretrained resnet18 parameters.  strict=False skips keys that
# do not exist in the GroupNorm model, so the newly added norm layers
# keep their fresh initialization.
sd = models.resnet18(pretrained=True).state_dict()
model.load_state_dict(sd, strict=False)

# Sanity check: matching parameters should now be identical (max abs
# diff of 0).  Hoist model.state_dict() out of the loop — the original
# rebuilt the full state dict on every iteration.
model_sd = model.state_dict()
for key in sd:
    ref = sd[key]
    actual = model_sd.get(key, None)
    if actual is not None:
        print((ref - actual).abs().max())
    else:
        print(key, ', not found in model')
```

The last loop is just a quick test to show, that the newly added `nn.GroupNorm`

layers won't be initialized from the checkpoint.

thank you @ptrblck

your code works for any resnet

however i want to know how do i change this line of code : model = resnet.ResNet(resnet.BasicBlock, [3, 4, 6, 3], num_classes=5, zero_init_residual=False,

groups=1, width_per_group=64, replace_stride_with_dilation=None,

norm_layer=MyGroupNorm)

to load resnext50?

I can't pass norm_layer=MyGroupNorm for resnext; if you could show an example for resnext50 that would be great! Thank you a lot.

@ptrblck i need 1 help from you!

i tried your bn to gn swapping code and it works but here is what happening with me :

- without groupnormalization swapping my model works fine
- when I swap bn layers with gn (like you did in the message above) I sometimes see the qwk score drop to 0.0, but with batchnorm layers I never get a 0.0 qwk score — why is this happening?

let me repeat again: "everything works fine when I don't swap bn layers with gn layers"

when i use gn layers my model starts performing very badly but groupnorm paper is telling different story!! what could be the reason for this? any suggestion or groupnorm modification for solving this issue?

I don't know why `GroupNorm`

layers might decrease the model performance.

How did you swap them and verify that the model is correctly using these layers?

@ptrblck i just changed my efficientnet model with your model like you posted above,

my previous model that was working fine :

```
class enetv2(nn.Module):
    """EfficientNet backbone with its classifier head swapped out."""

    def __init__(self, backbone, out_dim):
        super(enetv2, self).__init__()
        # Build the backbone and restore its pretrained weights.
        # NOTE(review): `enet` and `pretrained_model` are defined
        # elsewhere in the caller's environment.
        self.enet = enet.EfficientNet.from_name(backbone)
        self.enet.load_state_dict(torch.load(pretrained_model[backbone]))
        # New task-specific head; the backbone's own fc becomes a no-op
        # so enet(x) yields raw pooled features.
        self.myfc = nn.Linear(self.enet._fc.in_features, out_dim)
        self.enet._fc = nn.Identity()

    def extract(self, x):
        # Pooled backbone features (head is Identity).
        return self.enet(x)

    def forward(self, x):
        features = self.extract(x)
        return self.myfc(features)
```

here out_dim = 5

your model :

```
import torch.nn as nn
from torchvision.models import resnet


class MyGroupNorm(nn.Module):
    """Proxy norm layer constructible from a channel count alone.

    ResNet calls ``norm_layer(planes)`` with a single argument, while
    ``nn.GroupNorm`` also needs ``num_groups``; this wrapper fixes
    ``num_groups=2`` so it can be passed as ``norm_layer``.
    """

    def __init__(self, num_channels):
        super(MyGroupNorm, self).__init__()
        self.norm = nn.GroupNorm(num_groups=2, num_channels=num_channels,
                                 eps=1e-5, affine=True)

    def forward(self, x):
        x = self.norm(x)
        return x


# resnet18-shaped network with a 5-class head and GroupNorm in place of
# BatchNorm.
model = resnet.ResNet(resnet.BasicBlock, [2, 2, 2, 2], num_classes=5,
                      zero_init_residual=False, groups=1, width_per_group=64,
                      replace_stride_with_dilation=None,
                      norm_layer=MyGroupNorm)
print(model)
```

i just simply changed my efficientnet model with your model so that i can use groupnorm and everything else was as before!

Hi

when I use your method I get the following error:

```
Traceback (most recent call last):
File "train_discriminator.py", line 170, in <module>
norm_layer=MyGroupNorm)
File "/alto/nima/torch-env/lib/python3.6/site-packages/torchvision/models/resnet.py", line 192, in __init__
nn.init.constant_(m.weight, 1)
File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/init.py", line 176, in constant_
return _no_grad_fill_(tensor, val)
File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/init.py", line 59, in _no_grad_fill_
return tensor.fill_(val)
AttributeError: 'NoneType' object has no attribute 'fill_'
```

You would have to handle `MyGroupNorm`

differently, as this custom module doesn't contain the `.weight`

attribute, but instead registers `.norm`

, which has the `weight`

and `bias`

, if `affine`

was set to `True`

.

As a quick fix, you could check, if `m.weight`

is not `None`

and initialize it then.