Hi @ptrblck,
I wasn’t sure what you meant, but I made a fully contained example:
import torch
from torch import nn
import csv
import copy
class Net(nn.Module):
def __init__(self,old_state_dict,state_dict_map):
# run nn.Module's constructor
super(Net,self).__init__()
#--------------------------------------------------------------------
# build net
in_channels = 16
conv1 = nn.Conv2d(in_channels = 1,
out_channels = in_channels,
kernel_size = (3,3),
padding = (1,1))
batch_norm = nn.BatchNorm2d(num_features = 16)
activation = nn.ReLU()
conv2 = nn.Conv2d(in_channels = in_channels,
out_channels = in_channels,
kernel_size = (3,3),
padding = (1,1))
pooling = nn.MaxPool2d(kernel_size = (2,2))
# first stage
stages = [nn.Sequential(conv1,
batch_norm,
activation,
conv2,
batch_norm,
activation,
pooling)]
# next 4 stages
for i in range(4):
conv1 = nn.Conv2d(in_channels = in_channels,
out_channels = in_channels * 2,
kernel_size = (3,3),
padding = (1,1))
batch_norm = nn.BatchNorm2d(num_features = in_channels * 2)
conv2 = nn.Conv2d(in_channels = in_channels * 2,
out_channels = in_channels * 2,
kernel_size = (3,3),
padding = (1,1))
stages += [nn.Sequential(conv1,
batch_norm,
activation,
conv2,
batch_norm,
activation,
pooling)]
in_channels = in_channels * 2
# 6th stage, in_channels = 256
conv1 = nn.Conv2d(in_channels = in_channels,
out_channels = in_channels * 2,
kernel_size = (3,3),
padding = (1,1))
batch_norm = nn.BatchNorm2d(num_features = in_channels * 2)
stages += [nn.Sequential(conv1,
batch_norm,
activation,
pooling)]
in_channels = in_channels * 2
# final stage, in_channels = 512
conv1 = nn.Conv2d(in_channels = in_channels,
out_channels = in_channels * 2,
kernel_size = (2,2),
padding = (0,0))
batch_norm = nn.BatchNorm2d(num_features = in_channels * 2)
stages += [nn.Sequential(conv1,
batch_norm,
activation)]
# assign names to the stages for the state_dict
self.stage1 = stages[0]
self.stage2 = stages[1]
self.stage3 = stages[2]
self.stage4 = stages[3]
self.stage5 = stages[4]
self.stage6 = stages[5]
self.stage7 = stages[6]
#--------------------------------------------------------------------
# load the pre-trained parameters
# make a copy to use load_state_dict() method later
state_dict = copy.deepcopy(self.state_dict())
for new_key,old_key in state_dict_map.items():
state_dict[new_key] = old_state_dict[old_key]
self.load_state_dict(state_dict)
def forward(self,x):
x = self.stage1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.stage5(x)
x = self.stage6(x)
x = self.stage7(x)
# average of all pixels in each feature map
x = nn.functional.avg_pool2d(input = x,
kernel_size = x.shape[2:])
# flatten from N x 1024 x 1 x 1 to N x 1024
x = torch.flatten(input = x,
start_dim = 1)
return x
torch.manual_seed(42)
old_state_dict = {'module.layer1.0.weight':torch.randn((16, 1, 3, 3)),
'module.layer1.0.bias':torch.randn((16)),
'module.layer1.1.weight':torch.randn((16)),
'module.layer1.1.bias':torch.randn((16)),
'module.layer1.1.running_mean':torch.randn((16)),
'module.layer1.1.running_var':torch.randn((16)),
'module.layer2.0.weight':torch.randn((16, 16, 3, 3)),
'module.layer2.0.bias':torch.randn((16)),
'module.layer2.1.weight':torch.randn((16)),
'module.layer2.1.bias':torch.randn((16)),
'module.layer2.1.running_mean':torch.randn((16)),
'module.layer2.1.running_var':torch.randn((16)),
'module.layer4.0.weight':torch.randn((32, 16, 3, 3)),
'module.layer4.0.bias':torch.randn((32)),
'module.layer4.1.weight':torch.randn((32)),
'module.layer4.1.bias':torch.randn((32)),
'module.layer4.1.running_mean':torch.randn((32)),
'module.layer4.1.running_var':torch.randn((32)),
'module.layer5.0.weight':torch.randn((32, 32, 3, 3)),
'module.layer5.0.bias':torch.randn((32)),
'module.layer5.1.weight':torch.randn((32)),
'module.layer5.1.bias':torch.randn((32)),
'module.layer5.1.running_mean':torch.randn((32)),
'module.layer5.1.running_var':torch.randn((32)),
'module.layer7.0.weight':torch.randn((64, 32, 3, 3)),
'module.layer7.0.bias':torch.randn((64)),
'module.layer7.1.weight':torch.randn((64)),
'module.layer7.1.bias':torch.randn((64)),
'module.layer7.1.running_mean':torch.randn((64)),
'module.layer7.1.running_var':torch.randn((64)),
'module.layer8.0.weight':torch.randn((64, 64, 3, 3)),
'module.layer8.0.bias':torch.randn((64)),
'module.layer8.1.weight':torch.randn((64)),
'module.layer8.1.bias':torch.randn((64)),
'module.layer8.1.running_mean':torch.randn((64)),
'module.layer8.1.running_var':torch.randn((64)),
'module.layer10.0.weight':torch.randn((128, 64, 3, 3)),
'module.layer10.0.bias':torch.randn((128)),
'module.layer10.1.weight':torch.randn((128)),
'module.layer10.1.bias':torch.randn((128)),
'module.layer10.1.running_mean':torch.randn((128)),
'module.layer10.1.running_var':torch.randn((128)),
'module.layer11.0.weight':torch.randn((128, 128, 3, 3)),
'module.layer11.0.bias':torch.randn((128)),
'module.layer11.1.weight':torch.randn((128)),
'module.layer11.1.bias':torch.randn((128)),
'module.layer11.1.running_mean':torch.randn((128)),
'module.layer11.1.running_var':torch.randn((128)),
'module.layer13.0.weight':torch.randn((256, 128, 3, 3)),
'module.layer13.0.bias':torch.randn((256)),
'module.layer13.1.weight':torch.randn((256)),
'module.layer13.1.bias':torch.randn((256)),
'module.layer13.1.running_mean':torch.randn((256)),
'module.layer13.1.running_var':torch.randn((256)),
'module.layer14.0.weight':torch.randn((256, 256, 3, 3)),
'module.layer14.0.bias':torch.randn((256)),
'module.layer14.1.weight':torch.randn((256)),
'module.layer14.1.bias':torch.randn((256)),
'module.layer14.1.running_mean':torch.randn((256)),
'module.layer14.1.running_var':torch.randn((256)),
'module.layer16.0.weight':torch.randn((512, 256, 3, 3)),
'module.layer16.0.bias':torch.randn((512)),
'module.layer16.1.weight':torch.randn((512)),
'module.layer16.1.bias':torch.randn((512)),
'module.layer16.1.running_mean':torch.randn((512)),
'module.layer16.1.running_var':torch.randn((512)),
'module.layer18.0.weight':torch.randn((1024, 512, 2, 2)),
'module.layer18.0.bias':torch.randn((1024)),
'module.layer18.1.weight':torch.randn((1024)),
'module.layer18.1.bias':torch.randn((1024)),
'module.layer18.1.running_mean':torch.randn((1024)),
'module.layer18.1.running_var':torch.randn((1024))}
state_dict_map = {'stage1.0.weight':'module.layer1.0.weight',
'stage1.0.bias':'module.layer1.0.bias',
'stage1.1.weight':'module.layer1.1.weight',
'stage1.1.bias':'module.layer1.1.bias',
'stage1.1.running_mean':'module.layer1.1.running_mean',
'stage1.1.running_var':'module.layer1.1.running_var',
'stage1.3.weight':'module.layer2.0.weight',
'stage1.3.bias':'module.layer2.0.bias',
'stage1.4.weight':'module.layer2.1.weight',
'stage1.4.bias':'module.layer2.1.bias',
'stage1.4.running_mean':'module.layer2.1.running_mean',
'stage1.4.running_var':'module.layer2.1.running_var',
'stage2.0.weight':'module.layer4.0.weight',
'stage2.0.bias':'module.layer4.0.bias',
'stage2.1.weight':'module.layer4.1.weight',
'stage2.1.bias':'module.layer4.1.bias',
'stage2.1.running_mean':'module.layer4.1.running_mean',
'stage2.1.running_var':'module.layer4.1.running_var',
'stage2.3.weight':'module.layer5.0.weight',
'stage2.3.bias':'module.layer5.0.bias',
'stage2.4.weight':'module.layer5.1.weight',
'stage2.4.bias':'module.layer5.1.bias',
'stage2.4.running_mean':'module.layer5.1.running_mean',
'stage2.4.running_var':'module.layer5.1.running_var',
'stage3.0.weight':'module.layer7.0.weight',
'stage3.0.bias':'module.layer7.0.bias',
'stage3.1.weight':'module.layer7.1.weight',
'stage3.1.bias':'module.layer7.1.bias',
'stage3.1.running_mean':'module.layer7.1.running_mean',
'stage3.1.running_var':'module.layer7.1.running_var',
'stage3.3.weight':'module.layer8.0.weight',
'stage3.3.bias':'module.layer8.0.bias',
'stage3.4.weight':'module.layer8.1.weight',
'stage3.4.bias':'module.layer8.1.bias',
'stage3.4.running_mean':'module.layer8.1.running_mean',
'stage3.4.running_var':'module.layer8.1.running_var',
'stage4.0.weight':'module.layer10.0.weight',
'stage4.0.bias':'module.layer10.0.bias',
'stage4.1.weight':'module.layer10.1.weight',
'stage4.1.bias':'module.layer10.1.bias',
'stage4.1.running_mean':'module.layer10.1.running_mean',
'stage4.1.running_var':'module.layer10.1.running_var',
'stage4.3.weight':'module.layer11.0.weight',
'stage4.3.bias':'module.layer11.0.bias',
'stage4.4.weight':'module.layer11.1.weight',
'stage4.4.bias':'module.layer11.1.bias',
'stage4.4.running_mean':'module.layer11.1.running_mean',
'stage4.4.running_var':'module.layer11.1.running_var',
'stage5.0.weight':'module.layer13.0.weight',
'stage5.0.bias':'module.layer13.0.bias',
'stage5.1.weight':'module.layer13.1.weight',
'stage5.1.bias':'module.layer13.1.bias',
'stage5.1.running_mean':'module.layer13.1.running_mean',
'stage5.1.running_var':'module.layer13.1.running_var',
'stage5.3.weight':'module.layer14.0.weight',
'stage5.3.bias':'module.layer14.0.bias',
'stage5.4.weight':'module.layer14.1.weight',
'stage5.4.bias':'module.layer14.1.bias',
'stage5.4.running_mean':'module.layer14.1.running_mean',
'stage5.4.running_var':'module.layer14.1.running_var',
'stage6.0.weight':'module.layer16.0.weight',
'stage6.0.bias':'module.layer16.0.bias',
'stage6.1.weight':'module.layer16.1.weight',
'stage6.1.bias':'module.layer16.1.bias',
'stage6.1.running_mean':'module.layer16.1.running_mean',
'stage6.1.running_var':'module.layer16.1.running_var',
'stage7.0.weight':'module.layer18.0.weight',
'stage7.0.bias':'module.layer18.0.bias',
'stage7.1.weight':'module.layer18.1.weight',
'stage7.1.bias':'module.layer18.1.bias',
'stage7.1.running_mean':'module.layer18.1.running_mean',
'stage7.1.running_var':'module.layer18.1.running_var'}
net = Net(old_state_dict,state_dict_map)
for key,value in net.state_dict().items():
if 'num_batches_tracked' in key:
continue
param1 = value
param2 = old_state_dict[state_dict_map[key]]
is_equal = torch.allclose(param1,param2)
print(is_equal)
However, after I run this, I see the following in the console:
True
True
False
False
False
False
True
True
True
True
True
True
True
True
False
False
False
False
True
True
True
True
True
True
True
True
False
False
False
False
True
True
True
True
True
True
True
True
False
False
False
False
True
True
True
True
True
True
True
True
False
False
False
False
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
Which means that some parameter values in net.state_dict()
are not correct. For example, the first False
that you see above corresponds to the 'stage1.1.weight'
key. Comparing these two parameters:
>>> net.state_dict()['stage1.1.weight']
tensor([ 1.3956, -0.4016, -0.4760, 0.6024, -0.1390, -0.5199, -0.4298,
-0.9330,-0.3256, 0.9304, -0.2840, 0.8464, 0.0186, -1.6756, -1.9437, 0.0987])
>>> old_state_dict[state_dict_map['stage1.1.weight']]
tensor([-1.9006, 0.2286, 0.0249, -0.3460, 0.2868, -0.7308, 0.1748,
-1.0939,-1.6022, 1.3529, 1.2888, 0.0523, -1.5469, 0.7567, 0.7755, 2.0265])
Again, I am not sure where these numbers came from. In this case, net.state_dict()['stage1.1.weight']
should take the value of old_state_dict[state_dict_map['stage1.1.weight']]
.
I also want to add that while debugging here:
for new_key,old_key in state_dict_map.items():
state_dict[new_key] = old_state_dict[old_key]
In the case where new_key == 'stage1.1.weight'
and old_key == 'module.layer1.1.weight'
, then state_dict[new_key]
was indeed equal to old_state_dict[old_key]
, which suggests that the problem is with self.load_state_dict(state_dict)
.
I would greatly appreciate your feedback on this. Thanks a lot.