please find code snippet below:-
def load_torch_model(self):
if self.pre_trained_model == 'VGG16':
self.torch_model = models.vgg16(pretrained=True)
print(self.torch_model)
self.torch_model.classifier[6] = nn.Sequential(nn.Linear(in_features=self.in_features[0],
out_features=self.out_features[0], bias=True),
nn.ReLU(), nn.Dropout(p=self.dropout_ratio[0]),
nn.Linear(in_features=self.in_features[1],
out_features=self.out_features[1], bias=True),
nn.LogSoftmax(dim=1))
print(self.torch_model)
self.torch_model.load_state_dict(torch.load(<Path of model with .pt extension>))
I got traceback:-
Traceback (most recent call last):
File “C:/AI_Projects/Pytorch_to_TensorFlow_API/test.py”, line 20, in
convert.load_torch_model()
File “C:\AI_Projects\Pytorch_to_TensorFlow_API\converter.py”, line 43, in load_torch_model
self.torch_model.load_state_dict(torch.load(self.torch_model_path))
File “C:\Python37\lib\site-packages\torch\nn\modules\module.py”, line 818, in load_state_dict
state_dict = state_dict.copy()
File “C:\Python37\lib\site-packages\torch\nn\modules\module.py”, line 591, in getattr
type(self).name, name))
AttributeError: ‘VGG’ object has no attribute ‘copy’