I want to load a model saved by torch.jit.save and then convert it to another deep learning framework.
To do this, I need to extract the graph itself and the weight/bias and other tensors it uses.
I only know that the weight/bias tensors are stored in state_dict.
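A quick way to check this is to round-trip a small model through torch.jit.save and torch.jit.load and list the state_dict of the loaded module. This is a minimal sketch: the two-layer nn.Sequential is a hypothetical stand-in for mobilenetv2 so that only torch is needed, and an in-memory buffer replaces the file on disk. The keys are dotted module paths such as 0.weight, which is the naming any graph tensor has to be matched against.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for mobilenetv2: one conv followed by one batch norm.
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4)).eval()
traced = torch.jit.trace(model, torch.randn(1, 3, 8, 8))

# Save and reload through an in-memory buffer instead of a file.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

# The loaded ScriptModule still exposes its tensors keyed by dotted module path.
for name, tensor in loaded.state_dict().items():
    print(name, tuple(tensor.shape))
```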
However, in PyTorch 1.5.0, the order of the torch._C.Value objects in graph.inputs() is not the same as the order of the state_dict keys.
For example, for the mobilenetv2 model in torchvision, the first few lines of the graph generated by torch.jit are:
graph(%input.10 : Float(10, 3, 224, 224),
%1 : Float(1280, 320, 1, 1),
%2 : Float(1280),
%3 : Float(1280),
%4 : Float(1280),
%5 : Float(1280),
%6 : Float(160, 960, 1, 1),
%7 : Float(160),
%8 : Float(160),
%9 : Float(160)
but the first few keys of state_dict are
['features.0.0.weight', 'features.0.1.weight', 'features.0.1.bias', 'features.0.1.running_mean', 'features.0.1.running_var']
and their shapes are:
model.state_dict()[list(model.state_dict().keys())[0]].shape
torch.Size([32, 3, 3, 3])
model.state_dict()[list(model.state_dict().keys())[1]].shape
torch.Size([32])
model.state_dict()[list(model.state_dict().keys())[2]].shape
torch.Size([32])
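Clearly, the order of graph.inputs() does not match the order of the state_dict keys: the first graph input after the image is Float(1280, 320, 1, 1), while the first state_dict entry has shape [32, 3, 3, 3].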
So, is there any way to correctly extract the tensor (weight/bias) values from a torch._C.Graph?
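One possible direction, sketched under assumptions rather than offered as a confirmed answer, is to match tensors by name instead of by input position. In a module's inlined_graph (as opposed to a graph with the parameters lifted to inputs, like the dump above), each weight is read through a chain of prim::GetAttr nodes starting at the self input, and joining those attribute names yields the same dotted key that state_dict uses. This assumes a PyTorch release where ScriptModule exposes inlined_graph; the helpers iter_nodes, attr_path and named_tensors_from_graph are hypothetical names, and the torch._C Node/Value methods they call (kind, s, inputs, output, node, blocks, debugName) are internal APIs that may change between versions.

```python
import io

import torch
import torch.nn as nn


def iter_nodes(graph_or_block):
    """Yield every node, including nodes nested inside if/loop blocks."""
    for node in graph_or_block.nodes():
        yield node
        for block in node.blocks():
            yield from iter_nodes(block)


def attr_path(value):
    """Return the dotted path (e.g. "0.weight") of a value produced by a chain
    of prim::GetAttr nodes rooted at a graph input, or None otherwise."""
    names = []
    node = value.node()
    while node.kind() == "prim::GetAttr":
        names.append(node.s("name"))    # attribute read at this step
        value = list(node.inputs())[0]  # the object it was read from
        node = value.node()
    # Graph inputs (such as self) are outputs of the prim::Param node.
    if node.kind() != "prim::Param" or not names:
        return None
    return ".".join(reversed(names))


def named_tensors_from_graph(module):
    """Map graph value names (the %name shown when the graph is printed)
    to (state_dict key, tensor) pairs by following GetAttr chains."""
    state = module.state_dict()
    found = {}
    for node in iter_nodes(module.inlined_graph):
        if node.kind() != "prim::GetAttr":
            continue
        key = attr_path(node.output())
        if key in state:  # skips submodule reads such as "0" and non-tensors
            found[node.output().debugName()] = (key, state[key])
    return found


# Usage with a hypothetical two-layer model saved and reloaded in memory.
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4)).eval()
traced = torch.jit.trace(model, torch.randn(1, 3, 8, 8))
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

for graph_name, (key, tensor) in named_tensors_from_graph(loaded).items():
    print(graph_name, key, tuple(tensor.shape))
```

Because the lookup goes through names rather than positions, it does not depend on the order in which graph.inputs() or state_dict happen to list the tensors.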