This is a beautiful solution. I verified it based on your idea, and it works pretty well.
import torch
import torch.nn as nn
import torch.nn.functional as F
class ModelA(nn.Module):
    """Four-layer MLP mapping 2 -> 3 -> 4 -> 4 -> 3 features, ReLU after every layer.

    Layers are named ``A``, ``B``, ``C`` and ``D``. Their ``state_dict`` keys
    are therefore ``A.weight``, ``A.bias``, ..., ``D.bias``. ``A``-``C`` are
    the layers whose parameters get reused by another model.
    """

    def __init__(self):
        super().__init__()
        self.A = nn.Linear(2, 3)
        self.B = nn.Linear(3, 4)
        self.C = nn.Linear(4, 4)
        self.D = nn.Linear(4, 3)

    def forward(self, x):
        """Apply ``A``, ``B``, ``C``, ``D`` in order, each followed by ReLU.

        Args:
            x: input tensor whose last dimension is 2.

        Returns:
            Tensor whose last dimension is 3; every value is >= 0.
        """
        # NOTE(review): ReLU is also applied after the output layer D, which
        # clamps outputs to >= 0. Drop it if D is meant to emit logits or
        # unconstrained regression values.
        for layer in (self.A, self.B, self.C, self.D):
            x = F.relu(layer(x))
        return x
class ModelB(nn.Module):
    """Four-layer MLP mapping 2 -> 3 -> 4 -> 4 -> 2 features, ReLU after every layer.

    Layers ``A``, ``B`` and ``C`` have the same names and shapes as in
    ``ModelA``, so their ``state_dict`` entries are interchangeable. Only the
    output layer ``E`` (4 -> 2) is specific to this model.
    """

    def __init__(self):
        super().__init__()
        self.A = nn.Linear(2, 3)
        self.B = nn.Linear(3, 4)
        self.C = nn.Linear(4, 4)
        self.E = nn.Linear(4, 2)

    def forward(self, x):
        """Apply ``A``, ``B``, ``C``, ``E`` in order, each followed by ReLU.

        Args:
            x: input tensor whose last dimension is 2.

        Returns:
            Tensor whose last dimension is 2; every value is >= 0.
        """
        # NOTE(review): ReLU is also applied after the output layer E, which
        # clamps outputs to >= 0. Drop it if E is meant to emit logits or
        # unconstrained regression values.
        for layer in (self.A, self.B, self.C, self.E):
            x = F.relu(layer(x))
        return x
def _print_state_dict(state_dict):
    """Print each entry of *state_dict* in sorted key order.

    For every key this prints three lines: the key, the tensor's ``size()``
    and the tensor itself.
    """
    for key in sorted(state_dict):
        parameter = state_dict[key]
        print(key)
        print(parameter.size())
        print(parameter)


# Show the freshly initialised parameters of both models.
modelA = ModelA()
modelA_dict = modelA.state_dict()
print('-' * 40)
_print_state_dict(modelA_dict)

modelB = ModelB()
modelB_dict = modelB.state_dict()
print('-' * 40)
_print_state_dict(modelB_dict)

print('-' * 40)
print("modelB is going to use the ABC layers parameters from modelA")
pretrained_dict = modelA_dict
model_dict = modelB_dict
# 1. Filter out keys modelB does not have (here D.*). Also drop keys whose
#    tensor shape differs; load_state_dict would otherwise raise a
#    size-mismatch RuntimeError instead of skipping them.
pretrained_dict = {
    k: v
    for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
# 2. Overwrite the matching entries in modelB's state dict; E.* keeps its values.
model_dict.update(pretrained_dict)
# 3. Load the merged state dict back into modelB (values are copied in).
modelB.load_state_dict(model_dict)
modelB_dict = modelB.state_dict()
_print_state_dict(modelB_dict)
Output of the script above:
----------------------------------------
A.bias
torch.Size([3])
tensor([ 0.4012, -0.3587, 0.6650])
A.weight
torch.Size([3, 2])
tensor([[ 0.5574, 0.4757],
[-0.3795, -0.4850],
[ 0.2248, -0.3578]])
B.bias
torch.Size([4])
tensor([ 0.1353, -0.3448, 0.4272, -0.1463])
B.weight
torch.Size([4, 3])
tensor([[-0.4960, 0.2930, 0.1822],
[-0.4309, -0.4259, -0.3604],
[ 0.2976, 0.2279, -0.3805],
[-0.2423, -0.2915, 0.5130]])
C.bias
torch.Size([4])
tensor([-0.2964, -0.3516, -0.2900, 0.2390])
C.weight
torch.Size([4, 4])
tensor([[ 0.0877, 0.4150, -0.1938, 0.3659],
[-0.3505, 0.1734, -0.1803, 0.2914],
[ 0.3375, -0.2661, 0.4651, 0.0041],
[-0.1866, 0.0055, 0.0230, 0.0502]])
D.bias
torch.Size([3])
tensor([0.2733, 0.3856, 0.2848])
D.weight
torch.Size([3, 4])
tensor([[ 0.4498, 0.4846, -0.2461, 0.1043],
[-0.1462, -0.1684, 0.0155, -0.2861],
[-0.2750, 0.3607, 0.4295, -0.3481]])
----------------------------------------
A.bias
torch.Size([3])
tensor([-0.2486, -0.3553, -0.3503])
A.weight
torch.Size([3, 2])
tensor([[ 0.1880, -0.6102],
[-0.1288, 0.6273],
[ 0.1040, -0.5014]])
B.bias
torch.Size([4])
tensor([ 0.2349, 0.1911, -0.5200, -0.1111])
B.weight
torch.Size([4, 3])
tensor([[ 0.3223, 0.4178, -0.1244],
[-0.2392, 0.5335, -0.4440],
[-0.4544, 0.3134, 0.1886],
[-0.3317, 0.2892, -0.5672]])
C.bias
torch.Size([4])
tensor([ 0.4484, 0.3125, -0.1636, -0.1316])
C.weight
torch.Size([4, 4])
tensor([[-0.1965, -0.3447, -0.4057, -0.2020],
[-0.3002, 0.0170, -0.0360, 0.2502],
[ 0.3630, -0.2502, 0.2334, -0.1819],
[ 0.1432, 0.1483, -0.2965, -0.0004]])
E.bias
torch.Size([2])
tensor([-0.1594, 0.4471])
E.weight
torch.Size([2, 4])
tensor([[ 0.0461, -0.3409, 0.3723, -0.1613],
[-0.0548, 0.3238, -0.2238, 0.1237]])
----------------------------------------
modelB is going to use the ABC layers parameters from modelA
A.bias
torch.Size([3])
tensor([ 0.4012, -0.3587, 0.6650])
A.weight
torch.Size([3, 2])
tensor([[ 0.5574, 0.4757],
[-0.3795, -0.4850],
[ 0.2248, -0.3578]])
B.bias
torch.Size([4])
tensor([ 0.1353, -0.3448, 0.4272, -0.1463])
B.weight
torch.Size([4, 3])
tensor([[-0.4960, 0.2930, 0.1822],
[-0.4309, -0.4259, -0.3604],
[ 0.2976, 0.2279, -0.3805],
[-0.2423, -0.2915, 0.5130]])
C.bias
torch.Size([4])
tensor([-0.2964, -0.3516, -0.2900, 0.2390])
C.weight
torch.Size([4, 4])
tensor([[ 0.0877, 0.4150, -0.1938, 0.3659],
[-0.3505, 0.1734, -0.1803, 0.2914],
[ 0.3375, -0.2661, 0.4651, 0.0041],
[-0.1866, 0.0055, 0.0230, 0.0502]])
E.bias
torch.Size([2])
tensor([-0.1594, 0.4471])
E.weight
torch.Size([2, 4])
tensor([[ 0.0461, -0.3409, 0.3723, -0.1613],
[-0.0548, 0.3238, -0.2238, 0.1237]])