Save a model fine-tuned from a pre-trained model and run inference from the saved model

Hi All,

I want to save a pre-trained model after using transfer learning to train it on a separate dataset, and then use the trained model to make predictions on test data. My current code is below.

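import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.optim import lr_scheduler

# device and train_model are defined elsewhere (not shown)
model_conv = torchvision.models.regnet_y_32gf(weights = 'RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1')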
#model_conv = torchvision.models.efficientnet_b7(weights = 'EfficientNet_B7_Weights.IMAGENET1K_V1')

for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

# Only the parameters of the new final layer are passed to the optimizer;
# the frozen backbone is not updated.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=1)

# Specify a path
PATH = "state_dict_model.pt"

# Save
torch.save(model_conv.state_dict(), PATH)

# Load
model = model_conv()
model.load_state_dict(torch.load(PATH))
model.eval()

Error for the above code:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_27/3952178397.py in <module>
      6 
      7 # Load
----> 8 model = model_conv()
      9 model.load_state_dict(torch.load(PATH))
     10 model.eval()

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() missing 1 required positional argument: 'x'

Could anyone please help me with this?

Thanks & Best Regards
AMJS

model_conv is already an instance of the model created by torchvision.models.regnet_y_32gf(weights = 'RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1'), so calling it executes the forward pass, which expects an input tensor x. That is why you get the TypeError.
In your code snippet it seems you want to initialize the model variable with model_conv, which is wrong; you can use the model_conv object directly instead.