Hi,
I was just wondering: there is a small difference whenever I save a model's state dictionary and load it back. Could you please point out whether I am making a mistake, or whether this is an insignificant issue?
Regards,
P.S.: I have pasted sample code below.
from torchvision import models
import torch
model = models.resnet50(pretrained=True)
model.eval()
stateDict = model.state_dict()
torch.save(model.state_dict(), "test.pth")
reloadDict = torch.load("test.pth")
diff = stateDict.items() - reloadDict.items()
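If I'm not mistaken, subtracting one dict.items() view from another keeps the (key, value) pairs of the first dict that have no matching pair in the second, i.e. it removes the duplicated entries, as seen here: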
d1 = {'a': 1, 'b': 2, 'c': 10}
d2 = {'a': 3, 'b': 4}
print(d1.items() - d2.items())
> {('a', 1), ('b', 2), ('c', 10)}
d1 = {'a': 1, 'b': 2, 'c': 10}
d2 = {'a': 3, 'b': 4, 'c': 10}
print(d1.items() - d2.items())
> {('a', 1), ('b', 2)}
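With tensors as values, this effectively compares tensor objects rather than their numbers: torch.Tensor hashes by object identity (id), so a freshly loaded tensor never matches the original one even when every value is identical, and every entry ends up in diff. A minimal sketch with two hypothetical one-entry dicts:

```python
import torch

original = {"w": torch.ones(3)}
reloaded = {"w": torch.ones(3)}     # same values, but a different tensor object
same_object = {"w": original["w"]}  # the very same tensor object

# Tensors hash by identity, so equal-valued but distinct tensors never match.
print(len(original.items() - reloaded.items()))     # 1
# Only the identical tensor object is treated as a match.
print(len(original.items() - same_object.items()))  # 0
```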
so compare the tensor values directly:
from torchvision import models
import torch
model = models.resnet50(pretrained=True)
model.eval()
stateDict = model.state_dict()
torch.save(model.state_dict(), 'test.pth')
reloadDict = torch.load('test.pth')
diff = stateDict.items() - reloadDict.items()
len(diff)  # non-zero, but this does not mean the values differ
for key in stateDict:
    paramA = stateDict[key]
    paramB = reloadDict[key]
    # maximum absolute element-wise difference for this entry
    print(key, (paramA - paramB).abs().max())
which shows a maximum absolute difference of zero for every key.
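If only a yes/no answer is needed, torch.equal checks that two tensors have the same shape and identical elements. A minimal sketch of a hypothetical helper state_dicts_equal, assuming a small nn.Sequential stand-in model (to avoid downloading the ResNet-50 weights) and an in-memory io.BytesIO buffer in place of test.pth:

```python
import io

import torch
import torch.nn as nn


def state_dicts_equal(sd_a, sd_b):
    """Return True if both state dicts have the same keys and identical tensors."""
    if sd_a.keys() != sd_b.keys():
        return False
    # torch.equal compares shape and values, not object identity.
    return all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)


# Hypothetical stand-in model (assumption: avoids downloading ResNet-50 weights).
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
model.eval()
original = model.state_dict()

# Round-trip through an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.save(original, buffer)
buffer.seek(0)
reloaded = torch.load(buffer)

print(state_dicts_equal(original, reloaded))  # True

# Changing a single reloaded tensor makes the comparison fail.
reloaded["0.weight"].add_(1.0)
print(state_dicts_equal(original, reloaded))  # False
```

Swapping torch.equal for torch.allclose would turn this into a tolerance-based comparison.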
Thanks a bunch, I made a silly assumption there; I will be more careful next time before asking a question!
Thanks a heap!