import torch

trained_network = ...
torch.save(trained_network.state_dict(), 'final-model.pt')
new_network = init_model(...)
new_network.load_state_dict(torch.load('final-model.pt'))
evaluation(new_network)  # new_network performs badly, just like an untrained model
I would like to use the latter approach, since it is the recommended way. How can I fix it?
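A quick sanity check helps narrow the problem down. For a model whose entire state lives in its parameters and buffers, a `state_dict()` save/load round trip should reproduce the original outputs exactly. The sketch below assumes a hypothetical `make_model()` (a tiny MLP standing in for `init_model(...)`) and uses an in-memory buffer instead of `'final-model.pt'`. If this check passes for a toy model but the real network still behaves as if untrained, some of its state probably lives outside the `state_dict`.

```python
import io

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for init_model(...): a tiny MLP.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def state_dict_roundtrip_matches() -> bool:
    torch.manual_seed(0)
    trained = make_model()
    with torch.no_grad():  # stand-in for training: perturb every parameter
        for p in trained.parameters():
            p.add_(0.1)

    # Save only the state_dict, here to an in-memory file.
    buffer = io.BytesIO()
    torch.save(trained.state_dict(), buffer)
    buffer.seek(0)

    # Build a freshly initialized model and load the saved state into it.
    restored = make_model()
    restored.load_state_dict(torch.load(buffer))

    # Same parameters and same input should give identical outputs.
    trained.eval()
    restored.eval()
    x = torch.randn(5, 4)
    with torch.no_grad():
        return torch.equal(trained(x), restored(x))


print(state_dict_roundtrip_matches())
```

Calling `eval()` on both models matters for real networks that contain dropout or batch norm. It has no effect on this toy MLP.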
I have figured out the problem now. First, I compared the parameter names and values using the following code:
import torch

trained_network = Net()  # network with trained parameters
my_network = Net()  # network with default initialization
my_network.load_state_dict(trained_network.state_dict())
for (k1, v1), (k2, v2) in zip(my_network.state_dict().items(), trained_network.state_dict().items()):
    assert k1 == k2, "Parameter names do not match"
    if not torch.equal(v1, v2):
        print("Parameter values do not match:", k1)
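The loop above pairs entries by position, so it depends on both dictionaries listing their keys in the same order. A minimal alternative sketch compares them by key instead. It also reports keys that appear in only one of the two dicts. The helper name `diff_state_dicts` is hypothetical, and `nn.Linear` is used only as an example model.

```python
from typing import Dict, List

import torch
import torch.nn as nn


def diff_state_dicts(a: Dict[str, torch.Tensor], b: Dict[str, torch.Tensor]) -> List[str]:
    """Return keys missing from either dict, followed by keys whose tensors differ."""
    mismatches = sorted(set(a) ^ set(b))  # keys present in only one dict
    for key in sorted(set(a) & set(b)):
        if not torch.equal(a[key], b[key]):
            mismatches.append(key)
    return mismatches


a = nn.Linear(2, 2)
b = nn.Linear(2, 2)
b.load_state_dict(a.state_dict())
print(diff_state_dicts(a.state_dict(), b.state_dict()))  # []

with torch.no_grad():
    b.bias.add_(1.0)  # change one tensor on purpose
print(diff_state_dicts(a.state_dict(), b.state_dict()))  # ['bias']
```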
Then I found that the parameter values differ in the following module:
When I use torch.load(...), the is_inited flag in the module above is True, whereas it is False when I use my_network.load_state_dict(...). As a result, the loaded parameters are always overwritten whenever I use my_network.load_state_dict(), which causes the poor performance.
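The module itself is not shown, so the sketch below is a hypothetical reconstruction. It assumes that `is_inited` is a plain Python attribute that guards a data-dependent initialization on the first forward pass. `state_dict()` contains only parameters and persistent buffers, so a plain attribute is not saved. By contrast, pickling the whole module keeps every attribute, which would explain the difference described above. One common fix is to register the flag as a buffer, so that `load_state_dict()` restores it along with the weights. The class `LazyScale` and the helper `roundtrip` are illustrative names, not part of the original code.

```python
import torch
import torch.nn as nn


class LazyScale(nn.Module):
    """Hypothetical module: sets its scale from the first batch it sees."""

    def __init__(self, num_features: int, use_buffer_flag: bool):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(num_features))
        self.use_buffer_flag = use_buffer_flag
        if use_buffer_flag:
            # A buffer is part of state_dict(), so load_state_dict() restores it.
            self.register_buffer("is_inited", torch.tensor(False))
        else:
            # A plain attribute is NOT part of state_dict().
            self.is_inited = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not bool(self.is_inited):
            # Data-dependent init: this overwrites whatever was loaded into scale.
            with torch.no_grad():
                self.scale.copy_(1.0 / (x.std(dim=0) + 1e-6))
            if self.use_buffer_flag:
                self.is_inited.fill_(True)
            else:
                self.is_inited = True
        return x * self.scale


def roundtrip(use_buffer_flag: bool) -> bool:
    """Return True if the loaded scale survives the first forward pass."""
    torch.manual_seed(0)
    trained = LazyScale(3, use_buffer_flag)
    trained(torch.randn(8, 3) * 5)  # first forward pass triggers the init
    with torch.no_grad():
        trained.scale.mul_(2.0)  # stand-in for training updates

    fresh = LazyScale(3, use_buffer_flag)
    fresh.load_state_dict(trained.state_dict())
    fresh(torch.randn(8, 3))  # evaluation forward pass
    return torch.equal(fresh.scale, trained.scale)


print(roundtrip(use_buffer_flag=False))  # plain attribute: loaded scale is overwritten
print(roundtrip(use_buffer_flag=True))   # buffer: loaded scale is kept
```

Because the buffer is a tensor, it also moves with `.to(device)` along with the parameters, which a plain Python bool would not do.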