Packaging PyTorch topology first and checkpoints later

In 2019-2020 I wrote torch_liberator (GitHub: Kitware/torch_liberator), a library similar to the current incarnation of torch.package. The key difference is that torch-liberator relies on static analysis, via a separate library I wrote called liberator (GitHub: Kitware/liberator).

Here is how I used it. A model class (or the path to the file containing the class, plus the class name) was passed to the export_model_code function, which started a new, empty file. It first extracted the code for that class alone. Then, based on what the user listed in “export_modules” (analogous to “intern”-ing in torch.package), it iteratively and statically pulled in dependencies from other modules. The result was standalone Python code that could be exported.

Next, the user specified the keyword arguments used to construct an instance of the exported class. Where possible these were saved as Python literals; any kwargs that could not be saved that way were pickled. This minimized the use of pickling, which I consider a desirable property. Lastly, the user specified an existing snapshot on disk. These three pieces, (1) the model code, (2) the instantiation kwargs, and (3) the model state, were packaged into a zipfile. In production, this exported module could be loaded more or less independently of the training code: an instance of the class could be created, and the saved state could then be loaded into it.
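A minimal sketch of that three-part archive, using only the standard library. The helper names and the entry names (model.py, kwargs_literal.py, kwargs_pickled.pkl, model_state.pt) are hypothetical and do not reflect torch_liberator's real layout. A kwarg counts as a literal if its repr() round-trips through ast.literal_eval; everything else falls back to pickle. The snapshot is copied into the zip as raw bytes and is never unpickled.

```python
import ast
import pickle
import zipfile


def split_kwargs(kwargs):
    """Keep kwargs whose repr() round-trips through ast.literal_eval; pickle the rest."""
    literal, to_pickle = {}, {}
    for key, value in kwargs.items():
        try:
            if ast.literal_eval(repr(value)) == value:
                literal[key] = value
                continue
        except (ValueError, SyntaxError, TypeError):
            pass  # repr() is not a plain literal (e.g. an arbitrary object)
        to_pickle[key] = value
    return literal, to_pickle


def write_package(zip_path, model_code, kwargs, snapshot_path):
    """Bundle (1) model code, (2) constructor kwargs and (3) an on-disk snapshot."""
    literal, to_pickle = split_kwargs(kwargs)
    with zipfile.ZipFile(zip_path, "w") as zf:
        zf.writestr("model.py", model_code)
        zf.writestr("kwargs_literal.py", repr(literal))
        if to_pickle:
            zf.writestr("kwargs_pickled.pkl", pickle.dumps(to_pickle))
        # The snapshot file is copied byte-for-byte; it is never unpickled here.
        zf.write(snapshot_path, "model_state.pt")


def read_kwargs(zip_path):
    """Rebuild the full kwargs dict: literals first, then any pickled leftovers."""
    with zipfile.ZipFile(zip_path) as zf:
        kwargs = ast.literal_eval(zf.read("kwargs_literal.py").decode("utf-8"))
        if "kwargs_pickled.pkl" in zf.namelist():
            kwargs.update(pickle.loads(zf.read("kwargs_pickled.pkl")))
    return kwargs
```

For example, kwargs such as width=8 and classes=["a", "b"] end up as readable literals, while something like a datetime.date object is pickled, so pickle is only involved for the values that actually need it.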

This worked well, but it had a few problems. First, because it relied entirely on static analysis, it was fragile whenever the code did anything beyond the basics. Second, it didn’t munge any names except the root package name. The new torch.package module in PyTorch 1.9 solves many of these problems, and I want to move away from torch-liberator and start using torch.package.

This leads to my question. It seems that torch.package.PackageExporter must be given a model that has already been loaded with the desired model_state, because it pickles the entire model object at export time.

I was wondering whether there is a way to “prepare” a basic torch package on disk that contains only the model (perhaps with some random weight initialization, although ideally not even that, to save space). Then, given this exported model topology and an existing on-disk model_state checkpoint, I would like to simply add the checkpoint to the zipfile (perhaps overwriting some existing file in the zipfile).
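One mechanical detail: Python's zipfile module cannot overwrite an entry in place, so “overwriting” a placeholder means writing a new archive. A minimal sketch, assuming a plain zip with a hypothetical placeholder entry named model_state.pt and a hypothetical helper name. It is not torch.package's format, and this sketch does not establish whether PackageImporter would accept an archive edited this way.

```python
import zipfile


def replace_entry(base_zip, out_zip, checkpoint_path, entry="model_state.pt"):
    """Write a new archive: every entry of base_zip except `entry`, plus the checkpoint as `entry`."""
    with zipfile.ZipFile(base_zip) as src, zipfile.ZipFile(out_zip, "w") as dst:
        for info in src.infolist():
            if info.filename != entry:
                # Small code/metadata entries are read into memory and rewritten as-is.
                dst.writestr(info, src.read(info.filename))
        # The (large) checkpoint is copied from disk, never unpickled.
        dst.write(checkpoint_path, entry)
```

Only the small topology entries pass through memory; the checkpoint is streamed from its file, so the cost is dominated by copying its bytes rather than by torch.load.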

The use case is this: at the start of training, I would like to export the untrained topology into the training directory. Then, as model-state checkpoints are saved during training, I would like to quickly construct a torch package for any given checkpoint without loading the whole thing into Python memory and pickling it out again.

Is this at all possible with the current state of torch.package? Could it be modified to accommodate this use case? Is this use case clear and reasonable?

I’d like to clarify the use case.

At the end of training, my trainer saves the best N candidate checkpoints based on training and validation metrics. I want to create a torch package for each of these N checkpoints, giving N corresponding candidate packages that can then be evaluated by standalone prediction and evaluation code.

I could do this by looping over the checkpoints and calling model.load_state_dict(torch.load(checkpoint_path)) for each one, but this seems heavy-handed and unnecessary. The call to torch.load is expensive, and model.load_state_dict is not thread safe and has the side effect of modifying the model’s current state (I would like to avoid side effects where possible).

Is there a way to structure a torch.package so that I can make a copy of some “base” zipfile and then simply modify the copy by adding the desired checkpoint containing the pickled model_state_dict? Ideally, the base zipfile would store only the network topology. It might contain a dummy file marking where the weight checkpoint would go, but that file should be kept small to minimize the time it takes to copy the archive.
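If the base archive has no weights entry at all, the copy-then-add step reduces to a file copy plus an append. A minimal sketch under that assumption, with hypothetical function and entry names, again using a plain zip rather than torch.package's own format. Because each job only reads the shared base archive and writes its own output file, the N candidate packages can be built in parallel without touching a live model.

```python
import os
import shutil
import zipfile
from concurrent.futures import ThreadPoolExecutor


def package_checkpoint(base_zip, checkpoint_path, out_dir, entry="model_state.pt"):
    """Copy the topology-only base archive, then append one checkpoint file to the copy."""
    stem = os.path.splitext(os.path.basename(checkpoint_path))[0]
    out_path = os.path.join(out_dir, stem + ".zip")
    shutil.copyfile(base_zip, out_path)  # cheap when the base archive holds only code
    with zipfile.ZipFile(out_path, "a") as zf:
        zf.write(checkpoint_path, entry)  # raw bytes are copied; nothing is unpickled
    return out_path


def package_candidates(base_zip, checkpoint_paths, out_dir):
    """Build one package per checkpoint; results come back in input order."""
    # Each job touches only its own output file, so no model object is shared or mutated.
    with ThreadPoolExecutor() as pool:
        return list(
            pool.map(lambda p: package_checkpoint(base_zip, p, out_dir), checkpoint_paths)
        )
```

Appending under a name that already exists in the archive would create a duplicate entry (zipfile warns about this), which is why this variant assumes the base archive has no placeholder for the weights.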

I’d suggest posting this question under the torch.package / torch.deploy tag.