I know `torch.save(net.state_dict(), model_path)` can only save the weights, and `torch.save(net, model_path)` can save the weights and the structure of the network.
I know `torch.load(model_path)` can load the pretrained model.
I know `net.load_state_dict(torch.load(model_path))` can map the pretrained weights to the current net.
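A minimal sketch of the save/load calls described above, assuming PyTorch 1.13 or newer (needed for the `weights_only` argument of `torch.load`). `TinyNet` and the file names are hypothetical stand-ins for the real network (e.g. `VGG19`) and `model_path`. Note that loading the weights still requires constructing the class first.

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical stand-in for the real network (e.g. VGG19).
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


net = TinyNet()
tmpdir = tempfile.mkdtemp()
weights_path = os.path.join(tmpdir, "weights.pt")  # hypothetical file name
full_path = os.path.join(tmpdir, "full_model.pt")  # hypothetical file name

# 1) Save only the parameters/buffers (an OrderedDict of tensors).
torch.save(net.state_dict(), weights_path)

# 2) Pickle the whole module object. The class is recorded by reference
#    (module name + class name), so TinyNet must be importable at load time.
torch.save(net, full_path)

# Loading weights: the architecture has to be built first, then filled in.
restored = TinyNet()
restored.load_state_dict(torch.load(weights_path))

# Loading the whole object: since PyTorch 2.6, torch.load defaults to
# weights_only=True, which rejects arbitrary pickled objects, so it is
# turned off explicitly here (only do this for files you trust).
whole = torch.load(full_path, weights_only=False)
```

Both loaded models produce the same outputs as `net`; the full-object load only works because `TinyNet` is defined in the same script.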
But I found that I have to define the net before I load it from disk. Otherwise, this error occurs: `AttributeError: Can't get attribute 'VGG19' on <module '__main__'>`. I understand this is because `torch.save()` is a wrapper around Python's `pickle`, and the limitation comes from `pickle` itself.
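The `pickle` limitation can be reproduced with the standard library alone. In this sketch `VGG19` is a hypothetical placeholder class, not the real network; deleting it simulates loading the file in a script that never defined it. It assumes the code is run as a script, so the class lives in `__main__`, as in the error message above.

```python
import pickle


class VGG19:
    """Hypothetical placeholder; only its module and name matter to pickle."""

    def __init__(self):
        self.layers = 19


payload = pickle.dumps(VGG19())

# The pickle stream stores the class by name ("__main__", "VGG19"),
# not the class's source code.
print(b"VGG19" in payload)  # True

# Simulate a fresh script that never defined VGG19.
del VGG19

try:
    pickle.loads(payload)
    error_message = None
except AttributeError as err:
    error_message = str(err)
    print("Unpickling failed:", error_message)


# Defining a class with the same name again makes unpickling succeed,
# mirroring why the net must be defined before loading it.
class VGG19:
    def __init__(self):
        self.layers = 19


restored = pickle.loads(payload)
# __init__ is not called during unpickling; the attribute comes from the
# instance __dict__ saved in the pickle.
print(restored.layers)  # 19
```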
My question is:
How can I load a model if I do not know its structure/code?