I am doing:
torch.save(reconstruct_model.state_dict(), path.join(wandb.run.dir, 'best-model.pt'))
which saves my model parameters. Separately, I’m doing:
with open(path.join(wandb.run.dir, 'args.json'), 'w') as f:
    f.write(json.dumps(vars(args)))
I’m wondering if I can just save those args inside the best-model.pt file instead. Currently, when I load, I do:
spec = importlib.util.spec_from_file_location(
    "ReconstructionModel", saved_reconstruction_model_filepath)
model_spec = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_spec)
model = model_spec.ReconstructionModel(
    kernel_size=masked_args['kernel_size'],
    kernel_size_step=masked_args['kernel_size_step'])
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
where the masked_args are read from the args.json file.
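For what it's worth, torch.save pickles any Python object, so one way to do this is to save a dict holding both the state_dict and the args, then pull them apart at load time. A minimal sketch of the idea (using a stand-in nn.Linear and a hypothetical args dict in place of ReconstructionModel and vars(args)):

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for ReconstructionModel
args = {'kernel_size': 3, 'kernel_size_step': 2}  # stand-in for vars(args)

# Save the parameters and the args together in one checkpoint file.
ckpt_path = os.path.join(tempfile.mkdtemp(), 'best-model.pt')
torch.save({'state_dict': model.state_dict(), 'args': args}, ckpt_path)

# Loading recovers both pieces from the single file, so args.json
# is no longer needed to reconstruct the model.
loaded = torch.load(ckpt_path, map_location='cpu')
masked_args = loaded['args']

restored = nn.Linear(4, 2)  # would be ReconstructionModel(**masked_args)
restored.load_state_dict(loaded['state_dict'])
```

Keeping args.json around as well can still be handy for quick inspection in the wandb UI, since the checkpoint itself is binary.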