Assuming we reload for training, don't we need to load more data than `weights_only=True` allows? (I'm aware of the arbitrary code execution risk and other caveats, though.)
```python
import torch

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

# Restore everything stored in the checkpoint dictionary.
checkpoint = torch.load(PATH, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()   # for inference
# - or -
model.train()  # to resume training
```
Source: https://pytorch.org/tutorials/beginner/saving_loading_models.html#load
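For a concrete check, here is a minimal self-contained round trip (a sketch, not taken from the tutorial). It assumes a toy `nn.Linear` model and `torch.optim.SGD` with momentum as hypothetical stand-ins for `TheModelClass` and `TheOptimizerClass`, and an in-memory `io.BytesIO` buffer in place of `PATH`. It saves a checkpoint in the tutorial's dictionary layout after one training step (so the optimizer actually has per-parameter state) and reloads it with `weights_only=True`, which lets you see with your installed PyTorch version whether the optimizer state, epoch, and loss come through the restricted loader.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-ins for TheModelClass / TheOptimizerClass.
torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer holds momentum buffers.
inputs, targets = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Save a training checkpoint in the tutorial's layout; the buffer replaces PATH.
buffer = io.BytesIO()
torch.save({
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),  # plain float instead of a tensor
}, buffer)
buffer.seek(0)

# Reload into fresh objects using the restricted (weights-only) unpickler.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
checkpoint = torch.load(buffer, weights_only=True)
new_model.load_state_dict(checkpoint['model_state_dict'])
new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss_value = checkpoint['loss']
new_model.train()  # resume training mode
```

Here the checkpoint contains only tensors, dicts, lists, and plain Python scalars; a checkpoint that stores other object types may behave differently under `weights_only=True`.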