AttributeError: 'VGG' object has no attribute 'copy'

please find code snippet below:-

def load_torch_model(self):
    if self.pre_trained_model == 'VGG16':
        self.torch_model = models.vgg16(pretrained=True)
        print(self.torch_model)
        self.torch_model.classifier[6] = nn.Sequential(nn.Linear(in_features=self.in_features[0],
                                                                 out_features=self.out_features[0], bias=True),
                                                       nn.ReLU(), nn.Dropout(p=self.dropout_ratio[0]),
                                                       nn.Linear(in_features=self.in_features[1],
                                                                 out_features=self.out_features[1], bias=True),
                                                       nn.LogSoftmax(dim=1))
        print(self.torch_model)

    self.torch_model.load_state_dict(torch.load(<Path of model with .pt extension>))

I got traceback:-

Traceback (most recent call last):
File “C:/AI_Projects/Pytorch_to_TensorFlow_API/test.py”, line 20, in
convert.load_torch_model()
File “C:\AI_Projects\Pytorch_to_TensorFlow_API\converter.py”, line 43, in load_torch_model
self.torch_model.load_state_dict(torch.load(self.torch_model_path))
File “C:\Python37\lib\site-packages\torch\nn\modules\module.py”, line 818, in load_state_dict
state_dict = state_dict.copy()
File “C:\Python37\lib\site-packages\torch\nn\modules\module.py”, line 591, in getattr
type(self).name, name))
AttributeError: ‘VGG’ object has no attribute ‘copy’

How did you save the state_dict and what keys are inside it?
Saving and loading the state_dict using your model, works fine:

torch_model = models.vgg16(pretrained=False)

torch_model.classifier[6] = nn.Sequential(
    nn.Linear(in_features=4096,
        out_features=1, bias=True),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(in_features=1, out_features=2, bias=True),
    nn.LogSoftmax(dim=1)
)

torch_model(torch.randn(1, 3, 224, 224))
state_dict = torch_model.state_dict()
torch.save(state_dict, 'tmp.pt')

torch_model.load_state_dict(torch.load('tmp.pt'))
4 Likes