I have a long custom PyTorch pipeline that I have been using for about a year. Recently, however, I noticed that I call `self.model.eval()` every time I save the model's state dict. This does not seem to be common practice, and I am worried that it is wrong.
Snippet of code:
```python
import torch

def save_model_artifacts(
    self,
    path: str,
    valid_trues: torch.Tensor,
    valid_logits: torch.Tensor,
    valid_preds: torch.Tensor,
    valid_probs: torch.Tensor,
) -> None:
    """Save the weights for the best evaluation metric and also the OOF scores."""
    # self.model.eval()  # <- I have been calling this before every save.
    torch.save(
        {
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "oof_trues": valid_trues,
            "oof_logits": valid_logits,
            "oof_preds": valid_preds,
            "oof_probs": valid_probs,
        },
        path,
    )

def fit(...):  # arguments omitted
    # train
    # valid
    # save  (-> save_model_artifacts(...))
```
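One way to check whether `eval()` changes what gets saved is to compare the state dict of the same module in both modes and to round-trip it through `torch.save`/`torch.load`. The sketch below uses a toy model, not the pipeline above. The `make_model` helper and its layers are hypothetical, chosen only because `BatchNorm1d` and `Dropout` are the layers whose behaviour depends on the mode. As far as I can tell, `state_dict()` holds only parameters and persistent buffers, while the train/eval switch is the plain `training` attribute, so the saved contents should be identical.

```python
import io

import torch
import torch.nn as nn

def make_model() -> nn.Module:
    # Hypothetical toy model: BatchNorm1d and Dropout behave differently in train/eval mode.
    torch.manual_seed(0)
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Dropout(0.5), nn.Linear(8, 2))

model = make_model()

model.train()
sd_train = {k: v.clone() for k, v in model.state_dict().items()}  # snapshot in train mode
model.eval()
sd_eval = model.state_dict()  # same module, now in eval mode

same_keys = sd_train.keys() == sd_eval.keys()
same_values = all(torch.equal(sd_train[k], sd_eval[k]) for k in sd_train)
print(same_keys, same_values)  # True True

# Round trip: save while in eval mode, load into a freshly built model.
buf = io.BytesIO()
torch.save(model.state_dict(), buf)
buf.seek(0)
restored = make_model()  # newly constructed modules start in train mode
restored.load_state_dict(torch.load(buf))
print(restored.training)  # True: the eval/train mode is not part of the checkpoint
```

So whatever mode the model is in when `save_model_artifacts` runs, the stored weights and buffers are the same; the loading code decides the mode.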
I have done this for a long time and have not noticed any inference issues (the results look OK), but I want to be sure whether this is a potential problem and, if so, how to avoid it.
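Where an extra `eval()` could actually matter is the training loop: the mode persists on the model object, so if the next epoch starts without calling `train()`, dropout stays off and BatchNorm keeps using (and not updating) its running statistics. Below is a minimal sketch of a loop that sidesteps this. It assumes hypothetical helpers `train_one_epoch`, `validate` and `save_checkpoint`, a toy model, and random data; it is not the real pipeline. `num_batches_tracked` is used as a simple indicator of whether BatchNorm is still updating its statistics.

```python
import io

import torch
import torch.nn as nn

def train_one_epoch(model, x, y, optimizer):
    model.train()  # explicitly re-enter train mode, undoing any earlier eval()
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def validate(model, x):
    model.eval()  # dropout off, BatchNorm uses running stats
    return model(x)

def save_checkpoint(model, optimizer, target):
    # Mode-agnostic: state_dict() is the same in train and eval mode.
    torch.save(
        {"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict()},
        target,
    )

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))
bn = model[1]

for epoch in range(3):
    train_one_epoch(model, x, y, optimizer)
    validate(model, x)  # leaves the model in eval mode
    save_checkpoint(model, optimizer, io.BytesIO())

print(bn.num_batches_tracked.item())  # 3: one update per training forward pass

# Contrast: a forward pass in eval mode (what "training" would look like if
# train() were never called again) does not update the running statistics.
model.eval()
with torch.no_grad():
    model(x)
print(bn.num_batches_tracked.item())  # still 3
```

If you prefer to keep `self.model.eval()` inside `save_model_artifacts`, one option is to record `was_training = self.model.training` before it and call `self.model.train(was_training)` after `torch.save`, so that saving leaves the mode unchanged.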