Hi. I am trying to load a model with:
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')
# Create a model instance with the same architecture as the original model
model = model(inputs)
# Load the model's state dict
model.load_state_dict(torch.load('model.ckpt'))
But it says:
AttributeError: 'Tensor' object has no attribute 'load_state_dict'
model(inputs) will return a tensor in common use cases. Since you are re-assigning the model variable to it:
model = model(inputs)
model will now be a tensor instead of an nn.Module instance and will thus fail in the next line of code.
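A minimal sketch of what happens to the variable. It assumes a plain nn.Linear(4, 2) as a stand-in for the original model and random inputs of shape (3, 4). Neither comes from the original post; any nn.Module whose forward returns a tensor behaves the same way:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the original model and inputs.
model = nn.Linear(4, 2)
inputs = torch.randn(3, 4)

print(isinstance(model, nn.Module))  # True: still a module here

# Re-assigning the forward-pass result overwrites the module reference.
model = model(inputs)

print(isinstance(model, torch.Tensor))      # True: now a tensor
print(hasattr(model, "load_state_dict"))    # False: hence the AttributeError
```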
Assign the output to another variable and it should work, e.g.:
output = model(inputs)
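A sketch of the full save/load round trip with the output kept in its own variable. The Net class, its layer sizes, and the temporary checkpoint path are hypothetical placeholders for the original model and 'model.ckpt'. The calls used (torch.save, torch.load, load_state_dict) are the standard PyTorch ones:

```python
import os
import tempfile

import torch
import torch.nn as nn


class Net(nn.Module):
    # Hypothetical stand-in for the original model architecture.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = Net()
inputs = torch.randn(3, 4)

# Store the forward-pass result separately so `model` stays an nn.Module.
output = model(inputs)

# Save only the state dict (parameters and buffers), not the module object.
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")
torch.save(model.state_dict(), ckpt_path)

# Create a fresh instance with the same architecture, then load the weights.
restored = Net()
restored.load_state_dict(torch.load(ckpt_path))
restored.eval()

# The restored model reproduces the original output.
with torch.no_grad():
    print(torch.allclose(restored(inputs), output))  # True
```

Creating a new instance (restored = Net()) rather than reusing the name of the tensor output is the key point: load_state_dict must be called on an nn.Module with the same architecture as the one that was saved.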