Hello,
I was experimenting with load_state_dict with strict=False. Before asking my question, here is my model code:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder: input_dim -> 64 -> 32 -> 64 -> input_dim.

    Every linear layer is preceded by a BatchNorm1d over its input features.

    Args:
        input_dim: number of features in each input sample.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.input_dim = input_dim
        # encoder
        self.batchnorm1 = nn.BatchNorm1d(self.input_dim)
        self.linear1 = nn.Linear(self.input_dim, 64)
        self.batchnorm2 = nn.BatchNorm1d(64)
        self.linear2 = nn.Linear(64, 32)
        # decoder
        self.batchnorm3 = nn.BatchNorm1d(32)
        self.linear3 = nn.Linear(32, 64)
        self.batchnorm4 = nn.BatchNorm1d(64)
        self.linear4 = nn.Linear(64, self.input_dim)
        # relu (parameter-free, so it adds no state_dict entries)
        self.relu = nn.ReLU()

    def forward(self, h0):
        """Encode ``h0`` to a 32-dim code and decode it back to ``input_dim``."""
        # encoder forward pass
        h1 = self.relu(self.linear1(self.batchnorm1(h0)))
        h2 = self.linear2(self.batchnorm2(h1))
        # decoder forward pass
        h3 = self.relu(self.linear3(self.batchnorm3(h2)))
        # The last layer consumes h3; the original read h4 before assigning it.
        h4 = self.linear4(self.batchnorm4(h3))
        return h4
class Classifier(nn.Module):
    """Encoder (input_dim -> 64 -> 32) followed by a linear classification head.

    The encoder layers use the same attribute names as the encoder half of
    ``AutoEncoder`` (batchnorm1, linear1, batchnorm2, linear2), so those
    entries of an autoencoder state_dict match by key.

    Args:
        input_dim: number of features in each input sample.
        output_dim: number of outputs produced by the classification head.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # encoder
        self.batchnorm1 = nn.BatchNorm1d(self.input_dim)
        self.linear1 = nn.Linear(self.input_dim, 64)
        self.batchnorm2 = nn.BatchNorm1d(64)
        self.linear2 = nn.Linear(64, 32)
        # classification head
        self.batchnorm = nn.BatchNorm1d(32)
        # NOTE: the attribute is spelled "classifer" (sic). The spelling is kept
        # because it determines the state_dict keys ('classifer.weight',
        # 'classifer.bias'); renaming it would change those keys.
        self.classifer = nn.Linear(32, self.output_dim)
        # forward() applies a ReLU, but the original never created one.
        # nn.ReLU has no parameters, so the state_dict keys are unaffected.
        self.relu = nn.ReLU()

    def forward(self, h0):
        """Return the classification head's output for the batch ``h0``."""
        h1 = self.relu(self.linear1(self.batchnorm1(h0)))
        h2 = self.linear2(self.batchnorm2(h1))
        # Use the registered attribute name; "self.classifier" does not exist.
        out = self.classifer(self.batchnorm(h2))
        return out
After this I initialized both models as shown below:
# Both models are built with input_dim=130; the classifier has output_dim=1.
autoencoder_model = AutoEncoder(130)
classifier_model = Classifier(130, 1)
After this I tried classifier_model.load_state_dict(autoencoder_model.state_dict(), strict=False), which loads the encoder part of the autoencoder's weights into the classifier. This is the output I got (it shows the missing keys and unexpected keys):
_IncompatibleKeys(missing_keys=['batchnorm.weight', 'batchnorm.bias',
'batchnorm.running_mean', 'batchnorm.running_var', 'classifer.weight', 'classifer.bias'],
unexpected_keys=['batchnorm3.weight', 'batchnorm3.bias', 'batchnorm3.running_mean',
'batchnorm3.running_var', 'batchnorm3.num_batches_tracked', 'linear3.weight',
'linear3.bias', 'batchnorm4.weight', 'batchnorm4.bias', 'batchnorm4.running_mean',
'batchnorm4.running_var', 'batchnorm4.num_batches_tracked', 'linear4.weight',
'linear4.bias'])
The output looks correct, but my question is: why doesn't “batchnorm.num_batches_tracked” appear in the missing keys?