Taking ResNet-50 in torchvision as an example, I want to change all the `BatchNorm2d` layers to `GroupNorm`. How can I implement this efficiently?
@ptrblck can you please help me here?
I am unable to do this as well. I tried this:
import torch.nn as nn


def swap_bn_to_gn(model, num_groups=1, verbose=True):
    """Replace every ``nn.BatchNorm2d`` nested inside ``model`` with ``nn.GroupNorm``, in place.

    ``model.named_modules()`` yields dotted paths such as ``'layer1.0.bn1'``.
    ``getattr(model, 'layer1.0.bn1')`` does not follow those dots, which is what
    raises ``AttributeError: 'ResNet' object has no attribute 'layer1.0.bn1'``.
    Instead, each layer is replaced on its direct parent module.

    Args:
        model: module whose BatchNorm2d sub-layers should be swapped. The root
            module itself is never replaced (it has no parent to assign to).
        num_groups: number of groups for each new GroupNorm; must divide the
            layer's channel count (1 normalizes over all channels together).
        verbose: if True, print each swap.

    Returns:
        ``model`` (modified in place).

    Note:
        ``named_modules()`` de-duplicates, so a BatchNorm2d instance that is
        registered under several names is only replaced at the first of them.
    """
    modules = dict(model.named_modules())  # the root is stored under ''
    # Collect the targets first so the module tree is not mutated while it is walked.
    targets = [(name, module) for name, module in modules.items()
               if name and isinstance(module, nn.BatchNorm2d)]
    for name, bn in targets:
        # 'layer1.0.bn1' -> parent 'layer1.0', child 'bn1'; 'bn1' -> parent '' (the root).
        parent_name, _, child_name = name.rpartition('.')
        parent = modules[parent_name]
        gn = nn.GroupNorm(num_groups, bn.num_features, eps=bn.eps, affine=bn.affine)
        if verbose:
            print('Swapping {} with {}'.format(bn, gn))
        setattr(parent, child_name, gn)
    return model


if __name__ == "__main__":
    import torchvision.models as models

    model = models.resnet18()
    swap_bn_to_gn(model)
    print(model)
and it gives this error :
Swapping BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) with GroupNorm(1, 64, eps=1e-05, affine=True)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-26-dc2f23e093cc> in <module>
2 if isinstance(module, nn.BatchNorm2d):
3 # Get current bn layer
----> 4 bn = getattr(model, name)
5 # Create new gn layer
6 gn = nn.GroupNorm(1, bn.num_features)
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
592 return modules[name]
593 raise AttributeError("'{}' object has no attribute '{}'".format(
--> 594 type(self).__name__, name))
595
596 def __setattr__(self, name, value):
AttributeError: 'ResNet' object has no attribute 'layer1.0.bn1'
As described in this post, where this approach was also posted, I mentioned that this approach is hacky and would work only for simple modules.
If you want to properly swap the normalization layers, you should instead write a custom `nn.Module`, derive from the ResNet as the base class, and change the normalization layers in the `__init__` method.
You could reuse the `forward` method without changing it.
@ptrblck if you could give a working example with a torchvision model, it would help a lot.
I searched on Google and also posted on GitHub and Stack Overflow, but got no answer anywhere, so a working example would be of great help. Thank you anyway!
The current `ResNet` implementation seems to accept different normalization layers as an argument (you might need to update `torchvision` if your installed version doesn't have this argument).
Here is a small code snippet:
from torchvision.models import resnet
class MyGroupNorm(nn.Module):
    """Single-argument wrapper around ``nn.GroupNorm`` for use as a ResNet ``norm_layer``.

    torchvision's ``ResNet`` creates each normalization layer as
    ``norm_layer(num_channels)``, while ``nn.GroupNorm`` also needs the number
    of groups, so this proxy supplies the remaining arguments.

    Args:
        num_channels: number of channels of the normalized input.
        num_groups: number of channel groups; must divide ``num_channels``.
        eps: value added to the denominator for numerical stability.
        affine: if True, the wrapped layer has learnable ``weight``/``bias``.
    """

    def __init__(self, num_channels, num_groups=2, eps=1e-5, affine=True):
        super(MyGroupNorm, self).__init__()
        # Kept under ``.norm`` so existing state_dict keys (``*.norm.weight``) stay valid.
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels,
                                 eps=eps, affine=affine)

    # ResNet initialization code may access ``.weight`` on a norm layer directly
    # (e.g. with ``zero_init_residual=True``); forward it so the proxy does not
    # raise AttributeError. It is ``None`` when ``affine=False``.
    @property
    def weight(self):
        """Scale parameter of the wrapped GroupNorm (``None`` if not affine)."""
        return self.norm.weight

    @property
    def bias(self):
        """Shift parameter of the wrapped GroupNorm (``None`` if not affine)."""
        return self.norm.bias

    def forward(self, x):
        return self.norm(x)
# BasicBlock with [2, 2, 2, 2] blocks per stage (the resnet18 layout), with every
# normalization layer built by MyGroupNorm instead of nn.BatchNorm2d.
model = resnet.ResNet(
    resnet.BasicBlock,
    [2, 2, 2, 2],
    num_classes=1000,
    zero_init_residual=False,
    groups=1,
    width_per_group=64,
    replace_stride_with_dilation=None,
    norm_layer=MyGroupNorm,
)
print(model)
Note that I needed to define the proxy `MyGroupNorm` module, since the `norm_layer` will be initialized with only one argument (as seen in this line of code), while `nn.GroupNorm` needs two.
@ptrblck thank you a lot.
How do I load the weights of a pretrained model?
I tried this:
from torchvision.models import resnet
class MyGroupNorm(nn.Module):
    """Single-argument wrapper around ``nn.GroupNorm`` for use as a ResNet ``norm_layer``.

    torchvision's ``ResNet`` creates each normalization layer as
    ``norm_layer(num_channels)``, while ``nn.GroupNorm`` also needs the number
    of groups, so this proxy supplies the remaining arguments.

    Args:
        num_channels: number of channels of the normalized input.
        num_groups: number of channel groups; must divide ``num_channels``.
        eps: value added to the denominator for numerical stability.
        affine: if True, the wrapped layer has learnable ``weight``/``bias``.
    """

    def __init__(self, num_channels, num_groups=2, eps=1e-5, affine=True):
        super(MyGroupNorm, self).__init__()
        # Kept under ``.norm`` so existing state_dict keys (``*.norm.weight``) stay valid.
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels,
                                 eps=eps, affine=affine)

    # ResNet initialization code may access ``.weight`` on a norm layer directly
    # (e.g. with ``zero_init_residual=True``); forward it so the proxy does not
    # raise AttributeError. It is ``None`` when ``affine=False``.
    @property
    def weight(self):
        """Scale parameter of the wrapped GroupNorm (``None`` if not affine)."""
        return self.norm.weight

    @property
    def bias(self):
        """Shift parameter of the wrapped GroupNorm (``None`` if not affine)."""
        return self.norm.bias

    def forward(self, x):
        return self.norm(x)
# NOTE: BasicBlock with [3, 4, 6, 3] is the ResNet-34 layout, so this model does not
# match a (SE-)ResNeXt-50 checkpoint; num_classes=5 gives a 5-way fc head.
model = resnet.ResNet(
    resnet.BasicBlock,
    [3, 4, 6, 3],
    num_classes=5,
    zero_init_residual=False,
    groups=1,
    width_per_group=64,
    replace_stride_with_dilation=None,
    norm_layer=MyGroupNorm,
)
print(model)
and then this to load pretrained weight of resnext50 :
weights_path = '../input/pytorch-se-resnext/se_resnext50_32x4d-a260b3a4.pth'
# Strict load (the default): every key and shape must match the model built above.
# This checkpoint is an SE-ResNeXt-50 state_dict, so the keys do not line up.
model.load_state_dict(state_dict=torch.load(f=weights_path), strict=True)
but it gave me this error :
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-30-2c5332b5515b> in <module>
1 weights_path = '../input/pytorch-se-resnext/se_resnext50_32x4d-a260b3a4.pth'
----> 2 model.load_state_dict(torch.load(weights_path))
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
845 if len(error_msgs) > 0:
846 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 847 self.__class__.__name__, "\n\t".join(error_msgs)))
848 return _IncompatibleKeys(missing_keys, unexpected_keys)
849
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "conv1.weight", "bn1.norm.weight", "bn1.norm.bias", "layer1.0.bn1.norm.weight", "layer1.0.bn1.norm.bias", "layer1.0.bn2.norm.weight", "layer1.0.bn2.norm.bias", "layer1.1.bn1.norm.weight", "layer1.1.bn1.norm.bias", "layer1.1.bn2.norm.weight", "layer1.1.bn2.norm.bias", "layer1.2.bn1.norm.weight", "layer1.2.bn1.norm.bias", "layer1.2.bn2.norm.weight", "layer1.2.bn2.norm.bias", "layer2.0.bn1.norm.weight", "layer2.0.bn1.norm.bias", "layer2.0.bn2.norm.weight", "layer2.0.bn2.norm.bias", "layer2.0.downsample.1.norm.weight", "layer2.0.downsample.1.norm.bias", "layer2.1.bn1.norm.weight", "layer2.1.bn1.norm.bias", "layer2.1.bn2.norm.weight", "layer2.1.bn2.norm.bias", "layer2.2.bn1.norm.weight", "layer2.2.bn1.norm.bias", "layer2.2.bn2.norm.weight", "layer2.2.bn2.norm.bias", "layer2.3.bn1.norm.weight", "layer2.3.bn1.norm.bias", "layer2.3.bn2.norm.weight", "layer2.3.bn2.norm.bias", "layer3.0.bn1.norm.weight", "layer3.0.bn1.norm.bias", "layer3.0.bn2.norm.weight", "layer3.0.bn2.norm.bias", "layer3.0.downsample.1.norm.weight", "layer3.0.downsample.1.norm.bias", "layer3.1.bn1.norm.weight", "layer3.1.bn1.norm.bias", "layer3.1.bn2.norm.weight", "layer3.1.bn2.norm.bias", "layer3.2.bn1.norm.weight", "layer3.2.bn1.norm.bias", "layer3.2.bn2.norm.weight", "layer3.2.bn2.norm.bias", "layer3.3.bn1.norm.weight", "layer3.3.bn1.norm.bias", "layer3.3.bn2.norm.weight", "layer3.3.bn2.norm.bias", "layer3.4.bn1.norm.weight", "layer3.4.bn1.norm.bias", "layer3.4.bn2.norm.weight", "layer3.4.bn2.norm.bias", "layer3.5.bn1.norm.weight", "layer3.5.bn1.norm.bias", "layer3.5.bn2.norm.weight", "layer3.5.bn2.norm.bias", "layer4.0.bn1.norm.weight", "layer4.0.bn1.norm.bias", "layer4.0.bn2.norm.weight", "layer4.0.bn2.norm.bias", "layer4.0.downsample.1.norm.weight", "layer4.0.downsample.1.norm.bias", "layer4.1.bn1.norm.weight", "layer4.1.bn1.norm.bias", "layer4.1.bn2.norm.weight", "layer4.1.bn2.norm.bias", "layer4.2.bn1.norm.weight", "layer4.2.bn1.norm.bias", "layer4.2.bn2.norm.weight", 
"layer4.2.bn2.norm.bias", "fc.weight", "fc.bias".
Unexpected key(s) in state_dict: "layer0.conv1.weight", "layer0.bn1.weight", "layer0.bn1.bias", "layer0.bn1.running_mean", "layer0.bn1.running_var", "last_linear.weight", "last_linear.bias", "layer1.0.conv3.weight", "layer1.0.bn3.weight", "layer1.0.bn3.bias", "layer1.0.bn3.running_mean", "layer1.0.bn3.running_var", "layer1.0.se_module.fc1.weight", "layer1.0.se_module.fc1.bias", "layer1.0.se_module.fc2.weight", "layer1.0.se_module.fc2.bias", "layer1.0.downsample.0.weight", "layer1.0.downsample.1.weight", "layer1.0.downsample.1.bias", "layer1.0.downsample.1.running_mean", "layer1.0.downsample.1.running_var", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.1.conv3.weight", "layer1.1.bn3.weight", "layer1.1.bn3.bias", "layer1.1.bn3.running_mean", "layer1.1.bn3.running_var", "layer1.1.se_module.fc1.weight", "layer1.1.se_module.fc1.bias", "layer1.1.se_module.fc2.weight", "layer1.1.se_module.fc2.bias", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.2.conv3.weight", "layer1.2.bn3.weight", "layer1.2.bn3.bias", "layer1.2.bn3.running_mean", "layer1.2.bn3.running_var", "layer1.2.se_module.fc1.weight", "layer1.2.se_module.fc1.bias", "layer1.2.se_module.fc2.weight", "layer1.2.se_module.fc2.bias", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer2.0.conv3.weight", "layer2.0.bn3.weight", "layer2.0.bn3.bias", "layer2.0.bn3.running_mean", "layer2.0.bn3.running_var", "layer2.0.se_module.fc1.weight", "layer2.0.se_module.fc1.bias", "layer2.0.se_module.fc2.weight", "layer2.0.se_module.fc2.bias", 
"layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.downsample.1.weight", "layer2.0.downsample.1.bias", "layer2.0.downsample.1.running_mean", "layer2.0.downsample.1.running_var", "layer2.1.conv3.weight", "layer2.1.bn3.weight", "layer2.1.bn3.bias", "layer2.1.bn3.running_mean", "layer2.1.bn3.running_var", "layer2.1.se_module.fc1.weight", "layer2.1.se_module.fc1.bias", "layer2.1.se_module.fc2.weight", "layer2.1.se_module.fc2.bias", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.2.conv3.weight", "layer2.2.bn3.weight", "layer2.2.bn3.bias", "layer2.2.bn3.running_mean", "layer2.2.bn3.running_var", "layer2.2.se_module.fc1.weight", "layer2.2.se_module.fc1.bias", "layer2.2.se_module.fc2.weight", "layer2.2.se_module.fc2.bias", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer2.3.conv3.weight", "layer2.3.bn3.weight", "layer2.3.bn3.bias", "layer2.3.bn3.running_mean", "layer2.3.bn3.running_var", "layer2.3.se_module.fc1.weight", "layer2.3.se_module.fc1.bias", "layer2.3.se_module.fc2.weight", "layer2.3.se_module.fc2.bias", "layer2.3.bn1.weight", "layer2.3.bn1.bias", "layer2.3.bn1.running_mean", "layer2.3.bn1.running_var", "layer2.3.bn2.weight", "layer2.3.bn2.bias", "layer2.3.bn2.running_mean", "layer2.3.bn2.running_var", "layer3.0.conv3.weight", "layer3.0.bn3.weight", "layer3.0.bn3.bias", "layer3.0.bn3.running_mean", "layer3.0.bn3.running_var", "layer3.0.se_module.fc1.weight", "layer3.0.se_module.fc1.bias", "layer3.0.se_module.fc2.weight", "layer3.0.se_module.fc2.bias", "layer3.0.bn1.weight", 
"layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", "layer3.0.downsample.1.weight", "layer3.0.downsample.1.bias", "layer3.0.downsample.1.running_mean", "layer3.0.downsample.1.running_var", "layer3.1.conv3.weight", "layer3.1.bn3.weight", "layer3.1.bn3.bias", "layer3.1.bn3.running_mean", "layer3.1.bn3.running_var", "layer3.1.se_module.fc1.weight", "layer3.1.se_module.fc1.bias", "layer3.1.se_module.fc2.weight", "layer3.1.se_module.fc2.bias", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.2.conv3.weight", "layer3.2.bn3.weight", "layer3.2.bn3.bias", "layer3.2.bn3.running_mean", "layer3.2.bn3.running_var", "layer3.2.se_module.fc1.weight", "layer3.2.se_module.fc1.bias", "layer3.2.se_module.fc2.weight", "layer3.2.se_module.fc2.bias", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "layer3.3.conv3.weight", "layer3.3.bn3.weight", "layer3.3.bn3.bias", "layer3.3.bn3.running_mean", "layer3.3.bn3.running_var", "layer3.3.se_module.fc1.weight", "layer3.3.se_module.fc1.bias", "layer3.3.se_module.fc2.weight", "layer3.3.se_module.fc2.bias", "layer3.3.bn1.weight", "layer3.3.bn1.bias", "layer3.3.bn1.running_mean", "layer3.3.bn1.running_var", "layer3.3.bn2.weight", "layer3.3.bn2.bias", "layer3.3.bn2.running_mean", "layer3.3.bn2.running_var", "layer3.4.conv3.weight", "layer3.4.bn3.weight", "layer3.4.bn3.bias", "layer3.4.bn3.running_mean", "layer3.4.bn3.running_var", "layer3.4.se_module.fc1.weight", "layer3.4.se_module.fc1.bias", "layer3.4.se_module.fc2.weight", "layer3.4.se_module.fc2.bias", "layer3.4.bn1.weight", "layer3.4.bn1.bias", 
"layer3.4.bn1.running_mean", "layer3.4.bn1.running_var", "layer3.4.bn2.weight", "layer3.4.bn2.bias", "layer3.4.bn2.running_mean", "layer3.4.bn2.running_var", "layer3.5.conv3.weight", "layer3.5.bn3.weight", "layer3.5.bn3.bias", "layer3.5.bn3.running_mean", "layer3.5.bn3.running_var", "layer3.5.se_module.fc1.weight", "layer3.5.se_module.fc1.bias", "layer3.5.se_module.fc2.weight", "layer3.5.se_module.fc2.bias", "layer3.5.bn1.weight", "layer3.5.bn1.bias", "layer3.5.bn1.running_mean", "layer3.5.bn1.running_var", "layer3.5.bn2.weight", "layer3.5.bn2.bias", "layer3.5.bn2.running_mean", "layer3.5.bn2.running_var", "layer4.0.conv3.weight", "layer4.0.bn3.weight", "layer4.0.bn3.bias", "layer4.0.bn3.running_mean", "layer4.0.bn3.running_var", "layer4.0.se_module.fc1.weight", "layer4.0.se_module.fc1.bias", "layer4.0.se_module.fc2.weight", "layer4.0.se_module.fc2.bias", "layer4.0.bn1.weight", "layer4.0.bn1.bias", "layer4.0.bn1.running_mean", "layer4.0.bn1.running_var", "layer4.0.bn2.weight", "layer4.0.bn2.bias", "layer4.0.bn2.running_mean", "layer4.0.bn2.running_var", "layer4.0.downsample.1.weight", "layer4.0.downsample.1.bias", "layer4.0.downsample.1.running_mean", "layer4.0.downsample.1.running_var", "layer4.1.conv3.weight", "layer4.1.bn3.weight", "layer4.1.bn3.bias", "layer4.1.bn3.running_mean", "layer4.1.bn3.running_var", "layer4.1.se_module.fc1.weight", "layer4.1.se_module.fc1.bias", "layer4.1.se_module.fc2.weight", "layer4.1.se_module.fc2.bias", "layer4.1.bn1.weight", "layer4.1.bn1.bias", "layer4.1.bn1.running_mean", "layer4.1.bn1.running_var", "layer4.1.bn2.weight", "layer4.1.bn2.bias", "layer4.1.bn2.running_mean", "layer4.1.bn2.running_var", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "layer4.2.se_module.fc1.weight", "layer4.2.se_module.fc1.bias", "layer4.2.se_module.fc2.weight", "layer4.2.se_module.fc2.bias", "layer4.2.bn1.weight", "layer4.2.bn1.bias", "layer4.2.bn1.running_mean", 
"layer4.2.bn1.running_var", "layer4.2.bn2.weight", "layer4.2.bn2.bias", "layer4.2.bn2.running_mean", "layer4.2.bn2.running_var".
size mismatch for layer1.0.conv1.weight: copying a param with shape torch.Size([128, 64, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer1.0.conv2.weight: copying a param with shape torch.Size([128, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer1.1.conv1.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer1.1.conv2.weight: copying a param with shape torch.Size([128, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer1.2.conv1.weight: copying a param with shape torch.Size([128, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer1.2.conv2.weight: copying a param with shape torch.Size([128, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
size mismatch for layer2.0.conv1.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 64, 3, 3]).
size mismatch for layer2.0.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.0.downsample.0.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 64, 1, 1]).
size mismatch for layer2.1.conv1.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.1.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.2.conv1.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.2.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.3.conv1.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer2.3.conv2.weight: copying a param with shape torch.Size([256, 8, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
size mismatch for layer3.0.conv1.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 3, 3]).
size mismatch for layer3.0.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.0.downsample.0.weight: copying a param with shape torch.Size([1024, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1]).
size mismatch for layer3.1.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.1.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.2.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.2.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.3.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.3.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.4.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.4.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.5.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer3.5.conv2.weight: copying a param with shape torch.Size([512, 16, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 256, 3, 3]).
size mismatch for layer4.0.conv1.weight: copying a param with shape torch.Size([1024, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 3, 3]).
size mismatch for layer4.0.conv2.weight: copying a param with shape torch.Size([1024, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
size mismatch for layer4.0.downsample.0.weight: copying a param with shape torch.Size([2048, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 1, 1]).
size mismatch for layer4.1.conv1.weight: copying a param with shape torch.Size([1024, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
size mismatch for layer4.1.conv2.weight: copying a param with shape torch.Size([1024, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
size mismatch for layer4.2.conv1.weight: copying a param with shape torch.Size([1024, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
size mismatch for layer4.2.conv2.weight: copying a param with shape torch.Size([1024, 32, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
You are trying to load the ResNeXt `state_dict` into a ResNet, which won't work out of the box.
To load the pretrained `resnet18` parameters, you could use:
def load_pretrained_report(model, sd):
    """Load every entry of ``sd`` that fits ``model`` and report how each one compares.

    ``load_state_dict(..., strict=False)`` only tolerates missing/unexpected
    keys; a key present in both with a different shape (e.g. ``fc.weight`` when
    ``num_classes`` differs from the checkpoint) still raises. Such entries are
    therefore filtered out before loading.

    For each checkpoint entry this prints the max absolute difference after
    loading (0 for copied tensors), or a note when the key is absent from the
    model or has a different shape. Layers that were swapped for a differently
    structured module (e.g. ``bn1.weight`` vs ``bn1.norm.weight``) show up as
    "not found", i.e. they keep their fresh initialization.

    Args:
        model: module to load into (modified in place).
        sd: checkpoint state_dict, e.g. ``models.resnet18(pretrained=True).state_dict()``.

    Returns:
        list of checkpoint keys that were not loaded.
    """
    model_sd = model.state_dict()
    loadable = {k: v for k, v in sd.items()
                if k in model_sd and model_sd[k].shape == v.shape}
    model.load_state_dict(loadable, strict=False)

    model_sd = model.state_dict()  # fetched once after loading, not once per key
    skipped = []
    for key, ref in sd.items():
        actual = model_sd.get(key, None)
        if actual is None:
            print(key, ', not found in model')
            skipped.append(key)
        elif key not in loadable:
            print(key, ', shape mismatch: {} vs {}'.format(tuple(ref.shape), tuple(actual.shape)))
            skipped.append(key)
        else:
            print((ref.to(actual.device) - actual).abs().max())
    return skipped


if __name__ == "__main__":
    # ``models`` (torchvision.models) and ``model`` come from the earlier snippets.
    sd = models.resnet18(pretrained=True).state_dict()
    load_pretrained_report(model, sd)
The last loop is just a quick test to show that the newly added `nn.GroupNorm` layers won't be initialized from the pretrained weights.
Thank you @ptrblck, your code works for any ResNet.
However, I want to know how to change this line of code:
`model = resnet.ResNet(resnet.BasicBlock, [3, 4, 6, 3], num_classes=5, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=MyGroupNorm)`
to load ResNeXt-50.
I can't pass `norm_layer=MyGroupNorm` for ResNeXt; an example for ResNeXt-50 would be great! Thank you a lot.
@ptrblck I need one more piece of help from you!
I tried your BN-to-GN swapping code and it works, but here is what happens:
- without the GroupNorm swap, my model works fine
- when I swap the BN layers for GN layers (as you did in the message above), I sometimes see the QWK score drop to 0.0, but with BatchNorm layers I never get a QWK score of 0.0. Why is this happening?
Let me repeat: "everything works fine when I don't swap the BN layers for GN layers."
When I use GN layers, my model starts performing very badly, but the GroupNorm paper tells a different story! What could be the reason for this? Any suggestions or GroupNorm modifications to solve this issue?
I don't know why `GroupNorm` layers might decrease the model performance.
How did you swap them, and how did you verify that the model is correctly using these layers?
@ptrblck I just replaced my EfficientNet model with your model, as you posted above.
My previous model, which was working fine:
class enetv2(nn.Module):
    """EfficientNet backbone whose built-in ``_fc`` classifier is replaced by ``myfc``.

    The backbone is built with ``enet.EfficientNet.from_name(backbone)`` and its
    weights are loaded via ``torch.load(pretrained_model[backbone])``. Its ``_fc``
    head is then turned into ``nn.Identity`` so the backbone's output is whatever
    used to feed ``_fc``; ``myfc`` maps that to ``out_dim`` outputs.
    """

    def __init__(self, backbone, out_dim):
        super().__init__()
        self.enet = enet.EfficientNet.from_name(backbone)
        checkpoint = torch.load(pretrained_model[backbone])
        self.enet.load_state_dict(checkpoint)
        # Size the new head from the original classifier before discarding it.
        self.myfc = nn.Linear(self.enet._fc.in_features, out_dim)
        self.enet._fc = nn.Identity()

    def extract(self, x):
        """Run the identity-headed backbone on ``x``."""
        return self.enet(x)

    def forward(self, x):
        features = self.extract(x)
        return self.myfc(features)
Here `out_dim = 5`.
Your model:
from torchvision.models import resnet
class MyGroupNorm(nn.Module):
    """Single-argument wrapper around ``nn.GroupNorm`` for use as a ResNet ``norm_layer``.

    torchvision's ``ResNet`` creates each normalization layer as
    ``norm_layer(num_channels)``, while ``nn.GroupNorm`` also needs the number
    of groups, so this proxy supplies the remaining arguments.

    Args:
        num_channels: number of channels of the normalized input.
        num_groups: number of channel groups; must divide ``num_channels``.
        eps: value added to the denominator for numerical stability.
        affine: if True, the wrapped layer has learnable ``weight``/``bias``.
    """

    def __init__(self, num_channels, num_groups=2, eps=1e-5, affine=True):
        super(MyGroupNorm, self).__init__()
        # Kept under ``.norm`` so existing state_dict keys (``*.norm.weight``) stay valid.
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels,
                                 eps=eps, affine=affine)

    # ResNet initialization code may access ``.weight`` on a norm layer directly
    # (e.g. with ``zero_init_residual=True``); forward it so the proxy does not
    # raise AttributeError. It is ``None`` when ``affine=False``.
    @property
    def weight(self):
        """Scale parameter of the wrapped GroupNorm (``None`` if not affine)."""
        return self.norm.weight

    @property
    def bias(self):
        """Shift parameter of the wrapped GroupNorm (``None`` if not affine)."""
        return self.norm.bias

    def forward(self, x):
        return self.norm(x)
# BasicBlock with [2, 2, 2, 2] blocks per stage (the resnet18 layout), a 5-way fc
# head, and every normalization layer built by MyGroupNorm.
model = resnet.ResNet(
    resnet.BasicBlock,
    [2, 2, 2, 2],
    num_classes=5,
    zero_init_residual=False,
    groups=1,
    width_per_group=64,
    replace_stride_with_dilation=None,
    norm_layer=MyGroupNorm,
)
print(model)
I simply replaced my EfficientNet model with your model so that I could use GroupNorm; everything else was the same as before!
Hi
when I use your method I get the following error:
Traceback (most recent call last):
File "train_discriminator.py", line 170, in <module>
norm_layer=MyGroupNorm)
File "/alto/nima/torch-env/lib/python3.6/site-packages/torchvision/models/resnet.py", line 192, in __init__
nn.init.constant_(m.weight, 1)
File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/init.py", line 176, in constant_
return _no_grad_fill_(tensor, val)
File "/alto/nima/torch-env/lib/python3.6/site-packages/torch/nn/init.py", line 59, in _no_grad_fill_
return tensor.fill_(val)
AttributeError: 'NoneType' object has no attribute 'fill_'
You would have to handle `MyGroupNorm` differently, as this custom module doesn't contain the `.weight` attribute, but instead registers `.norm`, which has the `weight` and `bias` if `affine` was set to `True`.
As a quick fix, you could check whether `m.weight` is not `None` and only initialize it then.