How to change all BN layers to GN

Taking ResNet50 in torchvision as an example, I want to change all the BatchNorm2d layers to GroupNorm. How can I implement this efficiently?

@ptrblck can you please help me here?
I am unable to do this as well. I tried this:

# Attempt to replace every BatchNorm2d of a torchvision resnet18 with GroupNorm,
# looking modules up on the top-level model by the name named_modules() yields.
# NOTE(review): `nn` is used below but no `import torch.nn as nn` is visible in
# this snippet - presumably it was imported earlier in the session.
import torchvision.models as models
model  = models.resnet18()

#then this : 

for name, module in model.named_modules():  # nested modules get dotted names such as 'layer1.0.bn1' (see the AttributeError in the traceback that follows)
    if isinstance(module, nn.BatchNorm2d):
        # Get current bn layer
        bn = getattr(model, name)
        # Create new gn layer
        gn = nn.GroupNorm(1, bn.num_features)
        # Assign gn
        print('Swapping {} with {}'.format(bn, gn))
        setattr(model, name, gn)  # NOTE(review): presumably has the same dotted-name limitation as the getattr call - confirm

print(model)

and it gives this error :

Swapping BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) with GroupNorm(1, 64, eps=1e-05, affine=True)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-26-dc2f23e093cc> in <module>
      2     if isinstance(module, nn.BatchNorm2d):
      3         # Get current bn layer
----> 4         bn = getattr(model, name)
      5         # Create new gn layer
      6         gn = nn.GroupNorm(1, bn.num_features)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
    592                 return modules[name]
    593         raise AttributeError("'{}' object has no attribute '{}'".format(
--> 594             type(self).__name__, name))
    595 
    596     def __setattr__(self, name, value):

AttributeError: 'ResNet' object has no attribute 'layer1.0.bn1'

As described in this post, where this approach was also posted, I mentioned that this approach is hacky and would work only for simple modules.
If you want to properly swap the normalization layers, you should instead write a custom nn.Module, derive from the resnet as the base class, and change the normalization layers in the __init__ method.

You could reuse the forward without changing it.

@ptrblck if you could give a working example with a torchvision model, it would help a lot.
I searched on Google and also posted on GitHub and Stack Overflow, but got an answer nowhere, so a working demo would be of great help. Thank you anyway.

The current ResNet implementation seems to accept different normalization layers as an argument (you might need to update torchvision, if your installed version doesn’t have this argument).
Here is a small code snippet:

from torchvision.models import resnet


class MyGroupNorm(nn.Module):
    """Adapter that exposes ``nn.GroupNorm`` through a ``norm_layer(num_channels)`` call.

    The ResNet constructor instantiates its ``norm_layer`` with a single
    argument (the channel count), while ``nn.GroupNorm`` also needs the number
    of groups. This wrapper supplies the remaining arguments through defaults,
    so the class itself can be passed as ``norm_layer``. To use other settings,
    pass e.g. ``functools.partial(MyGroupNorm, num_groups=32)`` instead.

    Args:
        num_channels: Number of channels of the tensors to normalize.
        num_groups: Number of groups the channels are split into. Defaults to
            2. ``nn.GroupNorm`` requires ``num_channels`` to be divisible by
            this value.
        eps: Value added to the denominator for numerical stability.
        affine: If True, the wrapped layer has learnable per-channel weight
            and bias; if False, both are ``None``.
    """

    def __init__(self, num_channels, num_groups=2, eps=1e-5, affine=True):
        super(MyGroupNorm, self).__init__()
        # Keep the wrapped layer under the attribute name `norm`, so the
        # parameters are still registered as `norm.weight` / `norm.bias`.
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels,
                                 eps=eps, affine=affine)

    def forward(self, x):
        """Apply the wrapped group normalization to ``x`` and return the result."""
        return self.norm(x)

# Build a ResNet from BasicBlock units with [2, 2, 2, 2] blocks in its four
# layer groups and a 1000-class output. Passing norm_layer=MyGroupNorm makes the
# constructor create its normalization layers from the GroupNorm wrapper, which
# it instantiates with a single argument.
model = resnet.ResNet(resnet.BasicBlock, [2, 2, 2, 2], num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=MyGroupNorm)
# Print the module tree to inspect which normalization layers are in place.
print(model)

Note that I needed to define the proxy MyGroupNorm module, since the initialization of the norm_layer will only take one argument (as seen in this line of code), while nn.GroupNorm needs two.

1 Like

@ptrblck thank you a lot.
How do I load the weights of a pretrained model?
I tried this:

from torchvision.models import resnet


class MyGroupNorm(nn.Module):
    """Single-argument wrapper around a two-group ``nn.GroupNorm``.

    Only the channel count is taken, so the class can be handed to code that
    builds its normalization layers as ``norm_layer(num_channels)``.
    """

    def __init__(self, num_channels):
        super().__init__()
        # Two groups, eps of 1e-5 and learnable affine parameters; the layer is
        # registered under the name `norm`.
        group_norm = nn.GroupNorm(2, num_channels, eps=1e-5, affine=True)
        self.norm = group_norm

    def forward(self, x):
        """Return ``x`` after group normalization."""
        return self.norm(x)

# Build a ResNet from BasicBlock units with [3, 4, 6, 3] blocks in its four
# layer groups and a 5-class output, using MyGroupNorm as the normalization layer.
# NOTE(review): BasicBlock with [3, 4, 6, 3] looks like a ResNet-34 layout, not
# resnet18 or resnext50 - confirm it matches the checkpoint you intend to load.
model = resnet.ResNet(resnet.BasicBlock, [3, 4, 6, 3], num_classes=5, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=MyGroupNorm)
print(model)

and then this to load the pretrained weights of resnext50:


weights_path = '../input/pytorch-se-resnext/se_resnext50_32x4d-a260b3a4.pth'  # an SE-ResNeXt50 (32x4d) checkpoint, judging by the file name
model.load_state_dict(torch.load(weights_path))  # copies the checkpoint into the ResNet built above; the error that follows lists the key and shape differences

but it gave me this error :

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-30-2c5332b5515b> in <module>
      1 weights_path = '../input/pytorch-se-resnext/se_resnext50_32x4d-a260b3a4.pth'
----> 2 model.load_state_dict(torch.load(weights_path))

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    845         if len(error_msgs) > 0:
    846             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 847                                self.__class__.__name__, "\n\t".join(error_msgs)))
    848         return _IncompatibleKeys(missing_keys, unexpected_keys)
    849 

RuntimeError: Error(s) in loading state_dict for ResNet:
	Missing key(s) in state_dict: "conv1.weight", "bn1.norm.weight", "bn1.norm.bias", "layer1.0.bn1.norm.weight", "layer1.0.bn1.norm.bias", "layer1.0.bn2.norm.weight", "layer1.0.bn2.norm.bias", "layer1.1.bn1.norm.weight", "layer1.1.bn1.norm.bias", "layer1.1.bn2.norm.weight", "layer1.1.bn2.norm.bias", "layer1.2.bn1.norm.weight", "layer1.2.bn1.norm.bias", "layer1.2.bn2.norm.weight", "layer1.2.bn2.norm.bias", "layer2.0.bn1.norm.weight", "layer2.0.bn1.norm.bias", "layer2.0.bn2.norm.weight", "layer2.0.bn2.norm.bias", "layer2.0.downsample.1.norm.weight", "layer2.0.downsample.1.norm.bias", "layer2.1.bn1.norm.weight", "layer2.1.bn1.norm.bias", "layer2.1.bn2.norm.weight", "layer2.1.bn2.norm.bias", "layer2.2.bn1.norm.weight", "layer2.2.bn1.norm.bias", "layer2.2.bn2.norm.weight", "layer2.2.bn2.norm.bias", "layer2.3.bn1.norm.weight", "layer2.3.bn1.norm.bias", "layer2.3.bn2.norm.weight", "layer2.3.bn2.norm.bias", "layer3.0.bn1.norm.weight", "layer3.0.bn1.norm.bias", "layer3.0.bn2.norm.weight", "layer3.0.bn2.norm.bias", "layer3.0.downsample.1.norm.weight", "layer3.0.downsample.1.norm.bias", "layer3.1.bn1.norm.weight", "layer3.1.bn1.norm.bias", "layer3.1.bn2.norm.weight", "layer3.1.bn2.norm.bias", "layer3.2.bn1.norm.weight", "layer3.2.bn1.norm.bias", "layer3.2.bn2.norm.weight", "layer3.2.bn2.norm.bias", "layer3.3.bn1.norm.weight", "layer3.3.bn1.norm.bias", "layer3.3.bn2.norm.weight", "layer3.3.bn2.norm.bias", "layer3.4.bn1.norm.weight", "layer3.4.bn1.norm.bias", "layer3.4.bn2.norm.weight", "layer3.4.bn2.norm.bias", "layer3.5.bn1.norm.weight", "layer3.5.bn1.norm.bias", "layer3.5.bn2.norm.weight", "layer3.5.bn2.norm.bias", "layer4.0.bn1.norm.weight", "layer4.0.bn1.norm.bias", "layer4.0.bn2.norm.weight", "layer4.0.bn2.norm.bias", "layer4.0.downsample.1.norm.weight", "layer4.0.downsample.1.norm.bias", "layer4.1.bn1.norm.weight", "layer4.1.bn1.norm.bias", "layer4.1.bn2.norm.weight", "layer4.1.bn2.norm.bias", "layer4.2.bn1.norm.weight", "layer4.2.bn1.norm.bias", 
"layer4.2.bn2.norm.weight", "layer4.2.bn2.norm.bias", "fc.weight", "fc.bias". 
	Unexpected key(s) in state_dict: "layer0.conv1.weight", "layer0.bn1.weight", "layer0.bn1.bias", "layer0.bn1.running_mean", "layer0.bn1.running_var", "last_linear.weight", "last_linear.bias", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.se_module.fc1.weight", "layer1.0.se_module.fc1.bias", "layer1.0.se_module.fc2.weight", "layer1.0.se_module.fc2.bias", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.1.se_module.fc1.weight", "layer1.1.se_module.fc1.bias", "layer1.1.se_module.fc2.weight", "layer1.1.se_module.fc2.bias", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer1.2.se_module.fc1.weight", "layer1.2.se_module.fc1.bias", "layer1.2.se_module.fc2.weight", "layer1.2.se_module.fc2.bias", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.0.se_module.fc1.weight", "layer2.0.se_module.fc1.bias", "layer2.0.se_module.fc2.weight", "layer2.0.se_module.fc2.bias", 
"layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.1.se_module.fc1.weight", "layer2.1.se_module.fc1.bias", "layer2.1.se_module.fc2.weight", "layer2.1.se_module.fc2.bias", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.2.se_module.fc1.weight", "layer2.2.se_module.fc1.bias", "layer2.2.se_module.fc2.weight", "layer2.2.se_module.fc2.bias", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer2.3.se_module.fc1.weight", "layer2.3.se_module.fc1.bias", "layer2.3.se_module.fc2.weight", "layer2.3.se_module.fc2.bias", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.0.se_module.fc1.weight", "layer3.0.se_module.fc1.bias", "layer3.0.se_module.fc2.weight", "layer3.0.se_module.fc2.bias", "layer3.0.bn1.weight", 
"layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.1.se_module.fc1.weight", "layer3.1.se_module.fc1.bias", "layer3.1.se_module.fc2.weight", "layer3.1.se_module.fc2.bias", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.2.se_module.fc1.weight", "layer3.2.se_module.fc1.bias", "layer3.2.se_module.fc2.weight", "layer3.2.se_module.fc2.bias", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.3.se_module.fc1.weight", "layer3.3.se_module.fc1.bias", "layer3.3.se_module.fc2.weight", "layer3.3.se_module.fc2.bias", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.4.se_module.fc1.weight", "layer3.4.se_module.fc1.bias", "layer3.4.se_module.fc2.weight", "layer3.4.se_module.fc2.bias", "layer3.4.bn1.weight", "layer3.4.bn1.bias", 
"layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer3.5.se_module.fc1.weight", "layer3.5.se_module.fc1.bias", "layer3.5.se_module.fc2.weight", "layer3.5.se_module.fc2.bias", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.0.se_module.fc1.weight", "layer4.0.se_module.fc1.bias", "layer4.0.se_module.fc2.weight", "layer4.0.se_module.fc2.bias", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.1.se_module.fc1.weight", "layer4.1.se_module.fc1.bias", "layer4.1.se_module.fc2.weight", "layer4.1.se_module.fc2.bias", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "layer4.2.se_module.fc1.weight", "layer4.2.se_module.fc1.bias", "layer4.2.se_module.fc2.weight", "layer4.2.se_module.fc2.bias", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", 
"layer4.2.bn1.running_var", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var". 
	size mismatch for layer1.0.conv1.weight: copying a param with shape torch.Size([128, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for layer1.0.conv2.weight: copying a param with shape torch.Size([128, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for layer1.1.conv1.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for layer1.1.conv2.weight: copying a param with shape torch.Size([128, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for layer1.2.conv1.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for layer1.2.conv2.weight: copying a param with shape torch.Size([128, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for layer2.0.conv1.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).
	size mismatch for layer2.0.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for layer2.0.downsample.0.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 64, 1, 1]).
	size mismatch for layer2.1.conv1.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for layer2.1.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for layer2.2.conv1.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for layer2.2.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for layer2.3.conv1.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for layer2.3.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for layer3.0.conv1.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3]).
	size mismatch for layer3.0.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for layer3.0.downsample.0.weight: copying a param with shape torch.Size([1024, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1]).
	size mismatch for layer3.1.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for layer3.1.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for layer3.2.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for layer3.2.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for layer3.3.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for layer3.3.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for layer3.4.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for layer3.4.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for layer3.5.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for layer3.5.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
	size mismatch for layer4.0.conv1.weight: copying a param with shape torch.Size([1024, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 3, 3]).
	size mismatch for layer4.0.conv2.weight: copying a param with shape torch.Size([1024, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for layer4.0.downsample.0.weight: copying a param with shape torch.Size([2048, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 1, 1]).
	size mismatch for layer4.1.conv1.weight: copying a param with shape torch.Size([1024, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for layer4.1.conv2.weight: copying a param with shape torch.Size([1024, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for layer4.2.conv1.weight: copying a param with shape torch.Size([1024, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
	size mismatch for layer4.2.conv2.weight: copying a param with shape torch.Size([1024, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).

You are trying to load the ResNext state_dict into a ResNet, which won’t work out of the box.
To load the pretrained resnet18 parameters, you could use:

# Load the pretrained resnet18 parameters into `model` wherever they fit.
sd = models.resnet18(pretrained=True).state_dict()
model_sd = model.state_dict()

# strict=False only tolerates missing and unexpected keys; a tensor whose shape
# differs (e.g. fc.weight / fc.bias when num_classes != 1000) still raises a
# RuntimeError. Keep only the entries that exist in `model` with the same shape.
compatible_sd = {
    key: value
    for key, value in sd.items()
    if key in model_sd and model_sd[key].shape == value.shape
}
model.load_state_dict(compatible_sd, strict=False)


# Quick check of what was actually copied: a printed difference of 0 means the
# tensor was loaded; keys missing from `model` (e.g. the original BatchNorm
# entries) and shape-mismatched tensors are reported instead.
model_sd = model.state_dict()  # fetched once, not once per key
for key, ref in sd.items():
    actual = model_sd.get(key)
    if actual is None:
        print(key, ', not found in model')
    elif actual.shape != ref.shape:
        print(key, ', shape mismatch, not loaded')
    else:
        print((ref - actual).abs().max())

The last loop is just a quick test to show that the newly added nn.GroupNorm layers won’t be initialized from the pretrained state_dict.

thank you @ptrblck
your code works for any resnet
however i want to know how do i change this line of code : model = resnet.ResNet(resnet.BasicBlock, [3, 4, 6, 3], num_classes=5, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=MyGroupNorm)

to load resnext50?
I can’t pass norm_layer=MyGroupNorm for resnext. If you could show an example for resnext50, that would be great! Thank you a lot.

@ptrblck I need some help from you!
I tried your BN-to-GN swapping code and it works, but here is what is happening:

  1. Without the GroupNorm swap, my model works fine.
  2. When I swap the BN layers with GN (as you did in the message above), I sometimes see the QWK score drop to 0.0, but with BatchNorm layers I never get a QWK score of 0.0. Why is this happening?

Let me repeat: “everything works fine when I don’t swap the BN layers with GN layers”.
When I use GN layers my model starts performing very badly, but the GroupNorm paper tells a different story! What could be the reason for this? Any suggestions, or a GroupNorm modification, for solving this issue?

I don’t know why GroupNorm layers might decrease the model performance.
How did you swap them and verify that the model is correctly using these layers?

@ptrblck I just replaced my EfficientNet model with the model you posted above.

My previous model, which was working fine:

    
class enetv2(nn.Module):
    """EfficientNet backbone followed by a separate linear classification head.

    Args:
        backbone: Name passed to ``enet.EfficientNet.from_name`` and used as
            the key into ``pretrained_model`` to locate the weights file.
        out_dim: Number of outputs of the new linear head.
    """

    def __init__(self, backbone, out_dim):
        super().__init__()
        net = enet.EfficientNet.from_name(backbone)
        net.load_state_dict(torch.load(pretrained_model[backbone]))
        self.enet = net

        # Size the new head from the backbone's own classifier, then replace
        # that classifier with a pass-through so the backbone returns features.
        feature_dim = net._fc.in_features
        self.myfc = nn.Linear(feature_dim, out_dim)
        net._fc = nn.Identity()

    def extract(self, x):
        """Return the backbone output for ``x``."""
        features = self.enet(x)
        return features

    def forward(self, x):
        """Run the backbone and then the linear head on ``x``."""
        return self.myfc(self.extract(x))

here out_dim = 5

your model :

from torchvision.models import resnet


class MyGroupNorm(nn.Module):
    """Wrap a two-group ``nn.GroupNorm`` behind a one-argument constructor.

    The constructor takes only the channel count, which lets the class be
    passed wherever a ``norm_layer(num_channels)`` factory is expected.
    """

    def __init__(self, num_channels):
        super().__init__()
        # Registered as `norm`: two groups, eps of 1e-5, learnable weight/bias.
        wrapped = nn.GroupNorm(2, num_channels, eps=1e-5, affine=True)
        self.norm = wrapped

    def forward(self, x):
        """Normalize ``x`` with the wrapped GroupNorm layer."""
        return self.norm(x)

# Build a ResNet from BasicBlock units with [2, 2, 2, 2] blocks in its four
# layer groups and a 5-class output, using MyGroupNorm as the normalization layer.
model = resnet.ResNet(resnet.BasicBlock, [2, 2, 2, 2], num_classes=5, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=MyGroupNorm)
print(model)

I simply replaced my EfficientNet model with your model so that I could use GroupNorm; everything else was as before!

Hi

when I use your method I get the following error:

Traceback (most recent call last):
  File "train_discriminator.py", line 170, in <module>
    norm_layer=MyGroupNorm)
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torchvision/models/resnet.py", line 192, in __init__
    nn.init.constant_(m.weight, 1)
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/init.py", line 176, in constant_
    return _no_grad_fill_(tensor, val)
  File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/init.py", line 59, in _no_grad_fill_
    return tensor.fill_(val)
AttributeError: 'NoneType' object has no attribute 'fill_'

You would have to handle MyGroupNorm differently, as this custom module doesn’t contain the .weight attribute, but instead registers .norm, which has the weight and bias if affine was set to True.
As a quick fix, you could check whether m.weight is not None and only then initialize it.