Hi All,
I am looking to save a pre-trained model after using transfer learning to train it on a seperate dataset. Then use the trained model to make predictions on test data. My current code is as below.
model_conv = torchvision.models.regnet_y_32gf(weights = 'RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1')
#model_conv = torchvision.models.efficientnet_b7(weights = 'EfficientNet_B7_Weights.IMAGENET1K_V1')
for param in model_conv.parameters():
param.requires_grad = False
# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)
model_conv = model_conv.to(device)
criterion = nn.CrossEntropyLoss()
# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_conv = train_model(model_conv, criterion, optimizer_conv,
exp_lr_scheduler, num_epochs=1)
# Specify a path
PATH = "state_dict_model.pt"
# Save
torch.save(model_conv.state_dict(), PATH)
# Load
model = model_conv()
model.load_state_dict(torch.load(PATH))
model.eval()
Error for the above code:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_27/3952178397.py in <module>
6
7 # Load
----> 8 model = model_conv()
9 model.load_state_dict(torch.load(PATH))
10 model.eval()
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
TypeError: forward() missing 1 required positional argument: 'x'
Would anyone be able to please help me in this regards.
Thanks & Best Regards
AMJS