import torch
from LeNet5 import LeNet5
# Load the model object saved in LeNet5.pth. The `from LeNet5 import LeNet5`
# above is presumably needed so the pickled class can be resolved.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. On newer PyTorch releases torch.load defaults to weights_only=True,
# which rejects a whole pickled nn.Module; pass weights_only=False there if
# this line fails.
model = torch.load('LeNet5.pth')  # load the model

# Take an independent copy of the conv1 bias. .clone() allocates new storage,
# so later in-place edits to the model do not change this tensor.
layerWeight = model.state_dict()['conv1.0.bias'].clone()
print("----layerWeight is ",layerWeight)
print("----model is ", model.state_dict()['conv1.0.bias'])

# The tensors returned by state_dict() share storage with the model's
# parameters, so this in-place write modifies the model itself (but not the
# clone above).
model.state_dict()['conv1.0.bias'][0] = 1
print("++++layerWeight is", layerWeight)
print("++++model is", model.state_dict()['conv1.0.bias'])
print("layerWeight type is ",type(layerWeight))
print("model type is",type(model.state_dict()['conv1.0.bias']))

# Restore the saved bias into the model.
# `state_dict` is a method, so `model.state_dict[...] = ...` raises TypeError
# ('method' object does not support item assignment). Assigning into the dict
# returned by state_dict() would only rebind a key in that temporary dict and
# leave the model unchanged. load_state_dict copies the values into the
# model's parameters.
state = model.state_dict()
state['conv1.0.bias'] = layerWeight
model.load_state_dict(state)
print("****layerWeight is",layerWeight)
print("****model is", model.state_dict()['conv1.0.bias'])
In the code above, I try `model.state_dict['conv1.0.bias'] = layerWeight` (both `model.state_dict['conv1.0.bias']` and `layerWeight` are tensors), but I get the error shown in the picture below.
Could you please tell me how to solve it?