Can I change a custom resnet18 architecture and still use it with pretrained=True? I am making a small change to the architecture of a custom resnet18, and when I run it, loading the saved checkpoint fails with the error shown further down.
This is how the custom resnet18 is called:
model = Resnet_18.resnet18(pretrained=True, embedding_size=args.dim_embed)
The new change in the custom resnet18:
self.layer_attend1 = nn.Sequential(nn.Conv2d(layers[0], layers[0], stride=2, padding=1, kernel_size=3),
nn.AdaptiveAvgPool2d(1),
nn.Softmax(1))
The output of running the model is:
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
=> loading checkpoint 'runs/nondisjoint_l2norm/model_best.pth.tar'
Traceback (most recent call last):
File "main.py", line 352, in <module>
main()
File "main.py", line 145, in main
tnet.load_state_dict(checkpoint['state_dict'])
File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1406, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Tripletnet:
Missing key(s) in state_dict: "embeddingnet.embeddingnet.layer_attend1.0.weight", "embeddingnet.embeddingnet.layer_attend1.0.bias".
So, how can I make small architectural changes without retraining from scratch every time?
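For reference, here is a minimal sketch of one approach that might apply, assuming the only mismatch is the newly added layer: pass strict=False to load_state_dict so that keys missing from the checkpoint are skipped instead of raising. OldNet and NewNet are hypothetical stand-ins for the checkpointed network and the modified one, not the actual Tripletnet code. In main.py the equivalent call would presumably be tnet.load_state_dict(checkpoint['state_dict'], strict=False).

```python
import torch
import torch.nn as nn


class OldNet(nn.Module):
    """Hypothetical stand-in for the network that produced the checkpoint."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)


class NewNet(nn.Module):
    """Same network plus a new block that the checkpoint knows nothing about."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.layer_attend1 = nn.Sequential(
            nn.Conv2d(8, 8, kernel_size=3, stride=2, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Softmax(1),
        )


# Simulate a checkpoint saved before the architecture change.
checkpoint = {"state_dict": OldNet().state_dict()}

model = NewNet()
# strict=False copies every matching key and reports the rest instead of raising.
result = model.load_state_dict(checkpoint["state_dict"], strict=False)
missing = list(result.missing_keys)
unexpected = list(result.unexpected_keys)

print("missing keys:", missing)
print("unexpected keys:", unexpected)
```

The layers listed under missing keys keep their random initialization, so they would still need to be trained (for example by fine-tuning) before the model is useful. It is worth checking that the reported missing keys are exactly the new layers and nothing else.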