Change the BN to GN in resnet

Hi everyone,
I have a question about how to change the Normlization methond in resnet. When I first look at the code of resnet, I found that there is a attribute named norm_layer, where we could create BN layer. So, I try to initializing the norm_layer with nn.GroupNorm. However, I notice in the code of resnet, we just deliver the name nn.BatchNorm to norm_layer and use it to create our network. Sadly, the GN has two parameters, which means I can’t deliver the parameters of GN to ResNet class. So, my question is if I wanna change the Normlization from BN to GN, should I rewrite the Bottleneck and ResNet.
And I already try it but get some error with the weight shape and network input shape.

def norm2d(num_channels_per_group, planes):
    print("num_channels_per_group:{}".format(num_channels_per_group))
    if num_channels_per_group > 0:
        return GroupNorm(num_channels_per_group, planes, affine=True)
    else:
        return nn.BatchNorm2d(planes)

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 group_norm=0):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = norm2d(group_norm, planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = norm2d(group_norm, planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = norm2d(group_norm, planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, group_norm=0):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm2d(64, group_norm)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0],
                                       group_norm=group_norm)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       group_norm=group_norm)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       group_norm=group_norm)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       group_norm=group_norm)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, GroupNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        for m in self.modules():
            if isinstance(m, Bottleneck):
                m.bn3.weight.data.fill_(0)
            if isinstance(m, BasicBlock):
                m.bn2.weight.data.fill_(0)

    def _make_layer(self, block, planes, blocks, stride=1, group_norm=0):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                norm2d(planes * block.expansion, group_norm),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample,
                            group_norm))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, group_norm=group_norm))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def _resnet_gn(arch, block, layers, pretrained, num_classes, group_norm, **kwargs):
    model = ResNet(block, layers, num_classes, group_norm, **kwargs)
    if pretrained:
        model.load_state_dict(torch.load(MODEL_DIR))
    return model

class PD2SEModel(nn.Module):
    def __init__(self):
        super(PD2SEModel, self).__init__()
        # res_net_50_base = models.resnet50(pretrained = True)
        res_net_50_gn = _resnet_gn('resnet50', Bottleneck, [3, 4, 6, 3], pretrained=False, num_classes=45, group_norm=16)

        # res_net_50_base = models.resnet50()
        children_list = list(res_net_50_gn.children())
...

And the error is as follow:

Traceback (most recent call last):
File “/home/wzq/Work4money/AI-challenge-plant/code/train1.py”, line 305, in
train_model()
File “/home/wzq/Work4money/AI-challenge-plant/code/train1.py”, line 140, in train_model
out1, out2, out3 = PD2SE(img)
File “/usr/local/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py”, line 547, in call
result = self.forward(*input, **kwargs)
File “/home/wzq/Work4money/AI-challenge-plant/code/network.py”, line 55, in forward
severity_class = self.Layer0(x)
File “/usr/local/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py”, line 547, in call
result = self.forward(*input, **kwargs)
File “/usr/local/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/container.py”, line 92, in forward
input = module(input)
File “/usr/local/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py”, line 547, in call
result = self.forward(*input, **kwargs)
File “/usr/local/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/modules/normalization.py”, line 225, in forward
input, self.num_groups, self.weight, self.bias, self.eps)
File “/usr/local/anaconda3/envs/pytorch/lib/python3.7/site-packages/torch/nn/functional.py”, line 1692, in group_norm
torch.backends.cudnn.enabled)
RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [16] and input of shape [32, 64, 112, 112]

Hope for your reply. Thanks a lot.

It seems you are passing the arguments to your norm2d method in ResNet in the wrong order:

self.bn1 = norm2d(64, group_norm)

I assume it should be created as norm2d(group_norm, 64) as done in Bottleneck.

Hi, @ptrblck. Thanks for your reply. I have tried the other way to change it. And it works.The code is as follow:

class ResNet(torchvision.models.resnet.ResNet):
    def __init__(self, block, layers, num_classes=1000, group_norm=False):
        if group_norm:
            norm_layer = lambda x: nn.GroupNorm(32, x)
        else:
            norm_layer = None
        super(ResNet, self).__init__(block, layers, num_classes, norm_layer=norm_layer)
        if not group_norm:
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)  # change
            for i in range(2, 5):
                getattr(self, 'layer%d' % i)[0].conv1.stride = (2, 2)
                getattr(self, 'layer%d' % i)[0].conv2.stride = (1, 1)

But in my assignment, maybe I should try to change the BN in Shufflenet to GN. That’s really puzzled me, since I import the ShuffleUnit from the pytorchcv, and it dose not have the attribute like norm_layer in resnet. By the way, I have tried name_module method, and selecting the BN layer from ShuffleNet. But, it didn’t work, of course. We can’t change the value in tuple.

    modules = model.named_modules()

    for i in modules:
        if isinstance(i[1], nn.BatchNorm2d):
            i[2] = lambda x: nn.GroupNorm(32, x)
            print(i)

So I’m tring to rewrite the ShuffleUnit class in pytorchcv, and add the new attribute to it, just like this:

class ShuffleUnit(pytorchcv.models.shufflenetv2.ShuffleUnit):
    def __init__(self, in_channels, out_channels, downsample=False,
                 use_residual=True, use_se=False, group_norm=False):
        if group_norm:
            norm_layer = lambda x: nn.GroupNorm(32, x)
        else:
            norm_layer = None
        super(ShuffleUnit, self).__init__()

Actually, I stop at this point, I’m not sure this method would work. Is there any advises about this work? The code of ShuffleNet comes from pytorchcv is as follow:

class ShuffleUnit(nn.Module):
    """
    ShuffleNetV2 unit.

    Parameters:
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    downsample : bool
        Whether do downsample.
    use_se : bool
        Whether to use SE block.
    use_residual : bool
        Whether to use residual connection.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 downsample,
                 use_se,
                 use_residual):
        super(ShuffleUnit, self).__init__()
        self.downsample = downsample
        self.use_se = use_se
        self.use_residual = use_residual
        mid_channels = out_channels // 2

        self.compress_conv1 = conv1x1(
            in_channels=(in_channels if self.downsample else mid_channels),
            out_channels=mid_channels)
        self.compress_bn1 = nn.BatchNorm2d(num_features=mid_channels)
        self.dw_conv2 = depthwise_conv3x3(
            channels=mid_channels,
            stride=(2 if self.downsample else 1))
        self.dw_bn2 = nn.BatchNorm2d(num_features=mid_channels)
        self.expand_conv3 = conv1x1(
            in_channels=mid_channels,
            out_channels=mid_channels)
        self.expand_bn3 = nn.BatchNorm2d(num_features=mid_channels)
        if self.use_se:
            self.se = SEBlock(channels=mid_channels)
        if downsample:
            self.dw_conv4 = depthwise_conv3x3(
                channels=in_channels,
                stride=2)
            self.dw_bn4 = nn.BatchNorm2d(num_features=in_channels)
            self.expand_conv5 = conv1x1(
                in_channels=in_channels,
                out_channels=mid_channels)
            self.expand_bn5 = nn.BatchNorm2d(num_features=mid_channels)

        self.activ = nn.ReLU(inplace=True)
        self.c_shuffle = ChannelShuffle(
            channels=out_channels,
            groups=2)

Thanks a lot.

1 Like

Rewriting the model definition would of course work.
However, using getattr and setattr might be the hacky but faster way:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = nn.Conv2d(3, 3, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(3)
        
    def forward(self, x):
        x = self.bn1(self.conv1(x))
        return x

model = MyModel()
print(model)

for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d):
        # Get current bn layer
        bn = getattr(model, name)
        # Create new gn layer
        gn = nn.GroupNorm(1, bn.num_features)
        # Assign gn
        print('Swapping {} with {}'.format(bn, gn))
        setattr(model, name, gn)

print(model)

Let me know, if this would work for your model.

1 Like

Thank you @ptrblck. This really works for me. And before your answering I did not relize that we could use getattr and setattr to change the model. I will take notes about this.
Have a nice day!

@ptrblck i have a efficientnet model and i want to replace all of it’s batchnorm layers with groupnorm,how do i do it? here is my model code :


out_dim = 5
enet_type = 'efficientnet-b0'

pretrained_model = {
    'efficientnet-b0': '../input/efficientnet-pytorch/efficientnet-b0-08094119.pth'
}

    
class enetv2(nn.Module):
    def __init__(self, backbone, out_dim):
        super(enetv2, self).__init__()
        self.enet = enet.EfficientNet.from_name(backbone)
        self.enet.load_state_dict(torch.load(pretrained_model[backbone]))

        self.myfc = nn.Linear(self.enet._fc.in_features, out_dim)
        self.enet._fc = nn.Identity()

    def extract(self, x):
        return self.enet(x)

    def forward(self, x):
        x = self.extract(x)
        x = self.myfc(x)
        return x
model = enetv2(enet_type, out_dim=out_dim)
model = model.to(device)

with your code i get this error :


AttributeError Traceback (most recent call last)
in
2 if isinstance(module, nn.BatchNorm2d):
3 # Get current bn layer
----> 4 bn = getattr(model, name)
5 # Create new gn layer
6 gn = nn.GroupNorm(1, bn.num_features)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in getattr(self, name)
592 return modules[name]
593 raise AttributeError("’{}’ object has no attribute ‘{}’".format(
–> 594 type(self).name, name))
595
596 def setattr(self, name, value):

AttributeError: ‘enetv2’ object has no attribute ‘enet._bn0’

however model.enet._bn0
gives me this output :
BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)

My approach is hacky as described and if it’s not working on your particular model, I would recommend to create a custom model by deriving from your base model and to explicitly replace the layers you would like to change.

sorry @ptrblck i have never converted a model into custom model,i don’t know how to do it efficiently,a little help from you would be highly appreciated, i was using efficientnet from here : https://github.com/lukemelas/EfficientNet-PyTorch

@ptrblck i tried to change all the bn layers to gn layers and was able to change but while training model with groupnorm i am getting this error :

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<timed exec> in <module>

<ipython-input-22-cf57c28a75f1> in train_epoch(loader, optimizer)
      9         loss_func = criterion
     10         optimizer.zero_grad()
---> 11         logits = model(data)
     12         loss = loss_func(logits, target)
     13         loss.backward()

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

<ipython-input-10-902386208771> in forward(self, x)
     26 
     27     def forward(self, x):
---> 28         x = self.extract(x)
     29         x = self.myfc(x)
     30         return x

<ipython-input-10-902386208771> in extract(self, x)
     23 
     24     def extract(self, x):
---> 25         return self.enet(x)
     26 
     27     def forward(self, x):

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/kaggle/input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master/efficientnet_pytorch/model.py in forward(self, inputs)
    176 
    177         # Convolution layers
--> 178         x = self.extract_features(inputs)
    179 
    180         # Pooling and final linear layer

/kaggle/input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master/efficientnet_pytorch/model.py in extract_features(self, inputs)
    158 
    159         # Stem
--> 160         x = relu_fn(self._bn0(self._conv_stem(inputs)))
    161 
    162         # Blocks

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/normalization.py in forward(self, input)
    223     def forward(self, input):
    224         return F.group_norm(
--> 225             input, self.num_groups, self.weight, self.bias, self.eps)
    226 
    227     def extra_repr(self):

/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py in group_norm(input, num_groups, weight, bias, eps)
   1971         + list(input.size()[2:]))
   1972     return torch.group_norm(input, num_groups, weight, bias, eps,
-> 1973                             torch.backends.cudnn.enabled)
   1974 
   1975 

RuntimeError: expected device cpu but got device cuda:0

here is my full model :

enetv2(
  (enet): EfficientNet(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
    )
    (_bn0): GroupNorm(1, 32, eps=1e-05, affine=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        )
        (_bn1): GroupNorm(1, 32, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 16, eps=1e-05, affine=True)
      )
      (1): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 96, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          96, 96, kernel_size=(3, 3), stride=[2, 2], groups=96, bias=False
          (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
        )
        (_bn1): GroupNorm(1, 96, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          96, 4, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          4, 96, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 24, eps=1e-05, affine=True)
      )
      (2): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 144, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          144, 144, kernel_size=(3, 3), stride=(1, 1), groups=144, bias=False
          (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        )
        (_bn1): GroupNorm(1, 144, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          144, 6, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          6, 144, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 24, eps=1e-05, affine=True)
      )
      (3): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 144, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          144, 144, kernel_size=(5, 5), stride=[2, 2], groups=144, bias=False
          (static_padding): ZeroPad2d(padding=(1, 2, 1, 2), value=0.0)
        )
        (_bn1): GroupNorm(1, 144, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          144, 6, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          6, 144, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          144, 40, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 40, eps=1e-05, affine=True)
      )
      (4): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 240, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          240, 240, kernel_size=(5, 5), stride=(1, 1), groups=240, bias=False
          (static_padding): ZeroPad2d(padding=(2, 2, 2, 2), value=0.0)
        )
        (_bn1): GroupNorm(1, 240, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          240, 10, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          10, 240, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 40, eps=1e-05, affine=True)
      )
      (5): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 240, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          240, 240, kernel_size=(3, 3), stride=[2, 2], groups=240, bias=False
          (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
        )
        (_bn1): GroupNorm(1, 240, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          240, 10, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          10, 240, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 80, eps=1e-05, affine=True)
      )
      (6): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 480, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False
          (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        )
        (_bn1): GroupNorm(1, 480, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          480, 20, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          20, 480, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 80, eps=1e-05, affine=True)
      )
      (7): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 480, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False
          (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        )
        (_bn1): GroupNorm(1, 480, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          480, 20, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          20, 480, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 80, eps=1e-05, affine=True)
      )
      (8): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 480, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          480, 480, kernel_size=(5, 5), stride=[1, 1], groups=480, bias=False
          (static_padding): ZeroPad2d(padding=(2, 2, 2, 2), value=0.0)
        )
        (_bn1): GroupNorm(1, 480, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          480, 20, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          20, 480, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 112, eps=1e-05, affine=True)
      )
      (9): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 672, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False
          (static_padding): ZeroPad2d(padding=(2, 2, 2, 2), value=0.0)
        )
        (_bn1): GroupNorm(1, 672, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          672, 28, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          28, 672, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 112, eps=1e-05, affine=True)
      )
      (10): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 672, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False
          (static_padding): ZeroPad2d(padding=(2, 2, 2, 2), value=0.0)
        )
        (_bn1): GroupNorm(1, 672, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          672, 28, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          28, 672, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 112, eps=1e-05, affine=True)
      )
      (11): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 672, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          672, 672, kernel_size=(5, 5), stride=[2, 2], groups=672, bias=False
          (static_padding): ZeroPad2d(padding=(1, 2, 1, 2), value=0.0)
        )
        (_bn1): GroupNorm(1, 672, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          672, 28, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          28, 672, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          672, 192, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 192, eps=1e-05, affine=True)
      )
      (12): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 1152, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False
          (static_padding): ZeroPad2d(padding=(2, 2, 2, 2), value=0.0)
        )
        (_bn1): GroupNorm(1, 1152, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          1152, 48, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          48, 1152, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 192, eps=1e-05, affine=True)
      )
      (13): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 1152, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False
          (static_padding): ZeroPad2d(padding=(2, 2, 2, 2), value=0.0)
        )
        (_bn1): GroupNorm(1, 1152, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          1152, 48, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          48, 1152, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 192, eps=1e-05, affine=True)
      )
      (14): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 1152, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False
          (static_padding): ZeroPad2d(padding=(2, 2, 2, 2), value=0.0)
        )
        (_bn1): GroupNorm(1, 1152, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          1152, 48, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          48, 1152, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 192, eps=1e-05, affine=True)
      )
      (15): MBConvBlock(
        (_expand_conv): Conv2dStaticSamePadding(
          192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn0): GroupNorm(1, 1152, eps=1e-05, affine=True)
        (_depthwise_conv): Conv2dStaticSamePadding(
          1152, 1152, kernel_size=(3, 3), stride=[1, 1], groups=1152, bias=False
          (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        )
        (_bn1): GroupNorm(1, 1152, eps=1e-05, affine=True)
        (_se_reduce): Conv2dStaticSamePadding(
          1152, 48, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          48, 1152, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding(
          1152, 320, kernel_size=(1, 1), stride=(1, 1), bias=False
          (static_padding): Identity()
        )
        (_bn2): GroupNorm(1, 320, eps=1e-05, affine=True)
      )
    )
    (_conv_head): Conv2dStaticSamePadding(
      320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False
      (static_padding): Identity()
    )
    (_bn1): GroupNorm(1, 1280, eps=1e-05, affine=True)
    (_fc): Identity()
  )
  (myfc): Linear(in_features=1280, out_features=5, bias=True)
  (avg_pool): GeM(p=3.0000, eps=1e-06)
)






thanks, i just solved the error,so after replacing bn with gn i had to do model = model.to(device)

1 Like

Hi, i know it’s way to late.

But for everyone, who is trying to replace batch norm with group norm for more complex models like you wanted to do, you can use the following hacky approach :

def batch_norm_to_group_norm(layer):
    """Iterates over a whole model (or layer of a model) and replaces every batch norm 2D with a group norm

    Args:
        layer: model or one layer of a model like resnet34.layer1 or Sequential(), ...
    """
    for name, module in layer.named_modules():
        if name:
            try:
                # name might be something like: model.layer1.sequential.0.conv1 --> this wont work. Except this case
                sub_layer = getattr(layer, name)
                if isinstance(sub_layer, torch.nn.BatchNorm2d):
                    num_channels = sub_layer.num_features
                    # first level of current layer or model contains a batch norm --> replacing.
                    layer._modules[name] = torch.nn.GroupNorm(constants.GROUP_NORM_LOOKUP[num_channels], num_channels)
            except AttributeError:
                # go deeper: set name to layer1, getattr will return layer1 --> call this func again
                name = name.split('.')[0]
                sub_layer = getattr(layer, name)
                sub_layer = batch_norm_to_group_norm(sub_layer)
                layer.__setattr__(name=name, value=sub_layer)
    return layer

And the GROUP_NORM_LOOKUP is build as follows:

# group norm paper stated, that 16 channels per group are best
# group norm paper stated, that 8 or 32 groups are best
# channels per group = num_features / num_group
# channels per group have less influence when being >= 8 so here i try to set optimal values.
# Paper values:
# groups (G)
# 64   32   16   8    4    2    1 (=LN)
# 24.6 24.1 24.6 24.4 24.6 24.7 25.3  # validation error in %
# 0.5  -    0.5  0.3  0.5  0.6  1.2 
# channels per group
# 64   32   16   8    4    2    1 (=IN)
# 24.4 24.5 24.2 24.3 24.8 25.6 28.4  # validation error in %
# 0.2  0.3  -    0.1  0.6  1.4  4.2

# num_channels: num_groups
GROUP_NORM_LOOKUP = {
    16: 2,  # -> channels per group: 8
    32: 4,  # -> channels per group: 8
    64: 8,  # -> channels per group: 8
    128: 8,  # -> channels per group: 16
    256: 16,  # -> channels per group: 16
    512: 32,  # -> channels per group: 16
    1024: 32,  # -> channels per group: 32
    2048: 32,  # -> channels per group: 64
}

Of course you can use a different approach for setting the number of groups.

1 Like

Hi, thanks for your answer, but I have a doubt: What does the last line of code do?
layer.setattr(name=name, value=sub_layer)

Hi, suppose for the first call of this function that layer is a neural network containing a nn.Sequential() with batch norm.
While iterating we will at some point hit the sequential layer which will raise the AttributeError which then triggers the recursuve part.
In this case the function passes the current sub_layer (nn.Sequential()) separated from the actual model and changes every batch norm to a group norm. Now the actual model has not changed and so we need to overwrite the original nn.Sequential() with the returned one.

I think you are wondering about the name right? The name returned by layer.named_modules() is a chained name as described in my function (i can’t edit the post but model.layer1.sequential… is actually not correct. It should only show the theoretical build. The correct name would be layer1.0…).
Example with a ResNet: The network does have some attributes like self.layer1 which is a sequential. To iterate over the modules, pytorch returns name = layer1.0 or layer1.1 and so on… where layer1 is the attribute of the ResNet model and 0 is the attribute of layer1. Zero in that case is the index of the first module in the sequential as a string and the module is then added like a key value pair.

See the sequential init which is build as follows:

    def __init__(self, *args):
        super(Sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

The sequential actually adds the module by passing str(idx).

Actually you will overwrite the resnet.layer1 attribute multiple times but since this is done in no time, i did not insert another mechanism to simply check if we already processed the layer1 or not.

So as a result:

  1. i split the name to get layer1 out of layer1.0
  2. then i get the sequential by getattr
  3. i change batch norm to group norm inside the sequential
  4. i add the correct layer1 to the ResNet by setattr.

Thank you very much for your detailed answer. Is it true that if you modify the bn of sub_layer, the actual_layer will also be modified, so we don’t need to use the second setattr (layer.setattr(name=name, value=sub_layer))?

Ah now i got your point.
I did not thought about that but actually you are right. The setattr is not needed here. Thanks for pointing out :slight_smile: