Hey! I am loading the same network (right after initialization, before any training) from a pickle file, but it has different parameters every time I load it.
I tried two different ways of saving and loading:
import pickle

net = …
with open('train.pickle', 'wb') as f:
    pickle.dump([net], f)
Then:
with open('train.pickle', 'rb') as f:
    net1 = pickle.load(f)
with open('train.pickle', 'rb') as f:
    net2 = pickle.load(f)
net1 = net1[0]
net2 = net2[0]
print(net1.eval() == net2.eval())  # gives False
if net1.parameters() != net2.parameters():
    print(True)  # gives True
I also tried:
torch.save(model, PATH)
model = torch.load(PATH)
Hello
The recommended practice in PyTorch is to save the model's state_dict rather than pickling the whole model. Consider this minimal example:
import torch
import torch.nn as nn

m = nn.Linear(10, 2)
state = m.state_dict()  # parameters (and buffers) of m only
torch.save(state, 'm.pt')

# two freshly initialized models: their random weights differ
m1 = nn.Linear(10, 2)
m2 = nn.Linear(10, 2)
for p1, p2 in zip(m1.parameters(), m2.parameters()):
    print(torch.all(p1 == p2))
prints:
tensor(False)
tensor(False)
and, after loading the same state_dict into both models:
state = torch.load('m.pt')
m1.load_state_dict(state)
m2.load_state_dict(state)
for p1, p2 in zip(m1.parameters(), m2.parameters()):
    print(torch.all(p1 == p2))
prints:
tensor(True)
tensor(True)
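The same kind of check can be applied to the pickle approach from the original question. A minimal sketch, assuming a small nn.Linear stands in for the network (whose definition was elided) and using pickle.dumps/pickle.loads in memory instead of a file on disk:

```python
import pickle

import torch
import torch.nn as nn

# Stand-in for the question's network (assumption: any nn.Module behaves the same).
net = nn.Linear(10, 2)
blob = pickle.dumps([net])  # pickle a list holding the model, as in the question

# Unpickle twice, giving two independent copies.
net1 = pickle.loads(blob)[0]
net2 = pickle.loads(blob)[0]

# Compare the tensors element-wise, parameter by parameter.
same = all(torch.equal(p1, p2) for p1, p2 in zip(net1.parameters(), net2.parameters()))
print(same)          # True
print(net1 is net2)  # False: two distinct module objects
```

In this sketch the two loaded copies hold identical values even though they are different objects, so an "unequal" result is more likely to come from how the comparison is written than from pickling itself.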
Thank you for your answer.
But it didn't work.
It gives True for only a few parameters, not all of them.
I am using Google Colab, by the way (in case that matters).
I ran this example in Colab, and after loading the same state dict into m1 and m2, all their parameters are equal. Does running the exact code from my previous post give you different results?
I don't see how you reach that conclusion from this code. All it does is count the parameters of the m1 model; nothing is compared with m2.
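When you compare the results of parameters() directly (as in net1.parameters() != net2.parameters()), you are comparing two generators, and since they are two different Python objects they are never equal. You can check the type with type(m1.parameters()). That is why you need to iterate over the parameters and compare the tensors themselves.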
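A small sketch of why both checks in the question report a mismatch, assuming two nn.Linear(10, 2) modules where m2 is given m1's exact values via load_state_dict. It also covers a point not raised in the thread: Module.eval() switches the module to evaluation mode and returns the module itself, so net1.eval() == net2.eval() compares the two module objects, not their weights.

```python
import torch
import torch.nn as nn

m1 = nn.Linear(10, 2)
m2 = nn.Linear(10, 2)
m2.load_state_dict(m1.state_dict())  # m2 now holds exactly m1's values

# parameters() returns a new generator on each call; == on generators compares identity.
print(type(m1.parameters()))               # <class 'generator'>
print(m1.parameters() == m2.parameters())  # False, despite equal values

# eval() returns the module itself, so eval() == eval() compares module objects.
print(m1.eval() is m1)  # True

# Compare the tensors themselves instead.
print(all(torch.equal(p1, p2) for p1, p2 in zip(m1.parameters(), m2.parameters())))  # True
```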
Also, for a model this small you can simply print the parameters of both models and check that they match:
for p1, p2 in zip(m1.parameters(), m2.parameters()):
    print("m1")
    print(p1)
    print("m2")
    print(p2)
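For models too large to compare by eye, a hypothetical helper, mismatched_parameters (not part of PyTorch), can list exactly which entries differ. This is a sketch assuming both models share the same architecture; note that state_dict also contains buffers (for example BatchNorm running statistics), so those are checked as well.

```python
import torch
import torch.nn as nn


def mismatched_parameters(a: nn.Module, b: nn.Module) -> list[str]:
    """Return the names of state_dict entries whose tensors differ between a and b."""
    sa, sb = a.state_dict(), b.state_dict()
    if sa.keys() != sb.keys():
        raise ValueError("models have different state_dict keys")
    # torch.equal is True only if shapes and all values match exactly
    return [name for name in sa if not torch.equal(sa[name], sb[name])]


m1 = nn.Linear(10, 2)
m2 = nn.Linear(10, 2)
print(mismatched_parameters(m1, m2))  # independent random init, values differ

m2.load_state_dict(m1.state_dict())
print(mismatched_parameters(m1, m2))  # []
```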