Loading model weights and some variables from file

Newbie question here. I saved my model using
torch.save(dual_encoder.state_dict(), 'SAVED_MODEL.pt') in the file in which I train my model.

Now in the file in which I want to evaluate my model, I tried the following (based on this pytorch explanation):

import torch
import model

encoder_model = model.Encoder(
    input_size=model.emb_dim,
    hidden_size=200,
    vocab_size=model.vocab_len)

dual_encoder = model.DualEncoder(encoder_model)

dual_encoder.eval()
dual_encoder.training == False

dual_encoder.load_state_dict(torch.load('SAVED_MODEL.pt'))

I want to import the model parameters and some other variables from the model file, but instead the import starts the training again…
How can I make sure it just imports things without starting training again?

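The training flag is not part of state_dict, so it is not restored when you load your trained model; a freshly constructed module starts in training mode.
After load_state_dict(), you should call dual_encoder.eval(), which sets dual_encoder.training to False for the model and all of its submodules. (Writing dual_encoder.training == False is only a comparison and changes nothing.)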
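A minimal sketch to check what load_state_dict() restores and what eval() changes. The small nn.Sequential model is a hypothetical stand-in for the DualEncoder in this thread, and an in-memory buffer is used instead of SAVED_MODEL.pt only to keep the example self-contained:

```python
import io

import torch
from torch import nn

# Hypothetical stand-in for the DualEncoder from the thread.
model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(p=0.5))

# Save the weights to an in-memory buffer (a file path works the same way).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Rebuild the same architecture and load the saved weights into it.
restored = nn.Sequential(nn.Linear(4, 2), nn.Dropout(p=0.5))
restored.load_state_dict(torch.load(buffer))

# The state dict holds parameters and buffers only.
keys = list(restored.state_dict().keys())

# Mode right after loading, then after switching to evaluation mode.
mode_after_load = restored.training
restored.eval()
mode_after_eval = restored.training

print(keys, mode_after_load, mode_after_eval)
```

Calling eval() matters for layers such as dropout and batch norm, which behave differently in training and evaluation mode.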

I just tried, still the same thing happens. The problem I see is:
In order to be able to call

dual_encoder.load_state_dict(torch.load('SAVED_MODEL.pt'))
dual_encoder.eval()
dual_encoder.training == False (or True?)

I first need to import the model file. Running the import command alone will start the training (while not importing some variables!)…

Could you post the model.py code?
It seems that you have some training code in there that is not protected by
if __name__ == '__main__':

Importing a Python module executes every top-level ("global") statement in the script. You should therefore move all code into functions, like in the imagenet example.
The main function won't be called when you import the module, only if you run the script directly (e.g. from your terminal).
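A rough sketch of how model.py could be laid out under that advice. The names EMB_DIM, Encoder, train and main are placeholders, not the actual code from this thread, and the training loop is a dummy:

```python
# model.py (sketch): definitions at top level, training behind a guard.

EMB_DIM = 50  # hypothetical constant; safe to import from another file


class Encoder:
    """Placeholder for the real nn.Module definition."""

    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size


def train(num_epochs=2):
    """Dummy training loop; returns the model and the epochs it went through."""
    encoder = Encoder(EMB_DIM, 200)
    completed = []
    for epoch in range(num_epochs):
        completed.append(epoch)  # the real training step would go here
    return encoder, completed


def main():
    encoder, completed = train()
    print(f"trained for {len(completed)} epochs")


if __name__ == "__main__":
    # Runs for `python model.py`, but not for `import model`.
    main()
```

With this layout, the evaluation script can do `import model` and use model.EMB_DIM or model.Encoder without triggering train().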

Thanks, you are right. I wrote the training in model.py as global lines :no_mouth: … I will try to move the training code into a function.

EDIT: It works now, thanks!