Question about serialization while saving models in PyTorch

Assume that I save a model using torch.save(model, 'model.pt'). After saving, I add a new method to the model class in the source code, say new_feature(self, batch).

After this, if I do model = torch.load('model.pt'), I get a few warnings about the source code change, but surprisingly I can also call new_feature() on the loaded model (which was serialized before new_feature was added).

In this case, does torch.load() only load the parameters from model.pt? My understanding of serialization was that the entire model object would be dumped, not just the parameters. Can someone please shed some light on this topic for me?

Much appreciated.

That depends on how you saved your model. If you saved the state_dict with torch.save(model.state_dict(), 'model.pt'), only the parameters and buffers will be saved. If you save the model with torch.save(model, 'model.pt') (which is not recommended), your whole model will be pickled and saved. You may want to have a look at these guidelines.
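A minimal sketch of the two options, assuming a recent PyTorch version (the weights_only argument of torch.load exists since 1.13 and defaults to True from 2.6 on). To stay self-contained it uses a built-in nn.Sequential and in-memory io.BytesIO buffers instead of a custom class and 'model.pt':

```python
import io

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 2), nn.ReLU())

# Option 1 (recommended): save only the parameters and buffers.
weights_buf = io.BytesIO()
torch.save(model.state_dict(), weights_buf)
weights_buf.seek(0)

# Loading means rebuilding the architecture from code, then copying the tensors in.
restored = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
restored.load_state_dict(torch.load(weights_buf, weights_only=True))

# Option 2 (not recommended): pickle the whole module object.
model_buf = io.BytesIO()
torch.save(model, model_buf)
model_buf.seek(0)

# Arbitrary pickled objects need weights_only=False (only do this for trusted files).
whole = torch.load(model_buf, weights_only=False)
```

With the state_dict route the file contains only tensors, and the architecture always comes from your current code.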

@justusschock thank you for the prompt reply.

I did NOT save the state_dict; I saved the model directly, like torch.save(model, 'model.pt'). In this case, the whole model object was pickled. So why am I able to access new_feature(), which was added after model.pt was serialized and written to disk?

Let’s have a look at the underlying code:
torch.save basically only calls torch.serialization._save.

Inside this function, a nested function named persistent_id is defined, and, among other things, its return values are what gets pickled.
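This hook is plain pickle machinery. Below is a minimal stdlib sketch of it; Tagged, external_store, SidePickler and SideUnpickler are hypothetical names, not PyTorch API. Whenever persistent_id returns something other than None, pickle writes that return value in place of the object, and on loading the unpickler hands it back to persistent_load:

```python
import io
import pickle


class Tagged:
    """Hypothetical object type that should be kept out of the pickle stream."""

    def __init__(self, key):
        self.key = key


# Hypothetical side storage, loosely analogous to torch writing tensor data
# as separate records next to the pickle.
external_store = {}


class SidePickler(pickle.Pickler):
    def persistent_id(self, obj):
        # Any non-None return value is pickled *instead of* obj.
        if isinstance(obj, Tagged):
            external_store[obj.key] = obj
            return ("tagged", obj.key)
        return None  # None: pickle obj the normal way


class SideUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        # Receives exactly what persistent_id returned at save time.
        kind, key = pid
        if kind != "tagged":
            raise pickle.UnpicklingError(f"unknown persistent id {pid!r}")
        return external_store[key]


original = Tagged("x")
buf = io.BytesIO()
SidePickler(buf).dump({"a": original, "b": 1})
buf.seek(0)
loaded = SideUnpickler(buf).load()
```

After loading, loaded["a"] is the very same object as original, because only the id ("tagged", "x") went through the pickle.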

For subclasses of torch.nn.Module, this function does the following:

```python
if isinstance(obj, type) and issubclass(obj, nn.Module):
    if obj in serialized_container_types:
        return None
    serialized_container_types[obj] = True
    source_file = source = None
    try:
        source_file = inspect.getsourcefile(obj)
        source = inspect.getsource(obj)
    except Exception:  # saving the source is optional, so we can ignore any errors
        warnings.warn("Couldn't retrieve source code for container of "
                      "type " + obj.__name__ + ". It won't be checked "
                      "for correctness upon loading.")
    return ('module', obj, source_file, source)
```

This means the source file path and the source code of the class are pickled too. The class itself, however, is pickled by reference (its module and name), not by value: during loading, the class is imported again from your current source file, and the pickled source is only compared against the current source to emit a warning (SourceChangeWarning) if they differ. The saved attributes (parameters, buffers, submodules, etc.) are then restored onto that current class, which is why methods added after saving are available on the loaded model. This works as long as you don’t change the existing code in a way that no longer matches the attributes stored in the pickle.
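The whole round trip can be reproduced without PyTorch. The sketch below is a simplified, hypothetical imitation of what torch.save/torch.load do for module classes (the module demo_models, the dict current_source and the pickler classes are illustrative names): the class goes into the pickle together with its source, the class itself is resolved by name at load time, and a source mismatch only triggers a warning.

```python
import io
import pickle
import sys
import types
import warnings

# Hypothetical throwaway module playing the role of your model file.
mod = types.ModuleType("demo_models")
sys.modules["demo_models"] = mod
current_source = {}  # stand-in for inspect.getsource, which can't read exec'd code


def define_net(source):
    """(Re)define demo_models.Net from source text, like editing the file."""
    exec(source, mod.__dict__)
    current_source["Net"] = source


OLD = (
    "class Net:\n"
    "    def __init__(self):\n"
    "        self.weight = [1.0, 2.0]\n"
)
NEW = OLD + (
    "    def new_feature(self, batch):\n"
    "        return [w * x for w, x in zip(self.weight, batch)]\n"
)


class SourcePickler(pickle.Pickler):
    def __init__(self, f):
        super().__init__(f)
        self.seen = set()  # mirrors serialized_container_types

    def persistent_id(self, obj):
        if isinstance(obj, type) and obj.__module__ == "demo_models":
            if obj in self.seen:
                return None  # second visit: pickle the class normally, by name
            self.seen.add(obj)
            return ("module", obj, current_source[obj.__qualname__])
        return None


class SourceUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        kind, cls, saved_source = pid
        if kind != "module":
            raise pickle.UnpicklingError(f"unknown persistent id {pid!r}")
        # cls was unpickled by name, so it is the class as defined *now*.
        if saved_source != current_source[cls.__qualname__]:
            warnings.warn(f"source code of {cls.__qualname__} has changed")
        return cls


# Save with the old class definition.
define_net(OLD)
buf = io.BytesIO()
SourcePickler(buf).dump(mod.Net())

# "Edit the file" to add new_feature, then load the old pickle.
define_net(NEW)
buf.seek(0)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = SourceUnpickler(buf).load()
```

Here model.new_feature([3.0, 4.0]) returns [3.0, 8.0] even though the object was pickled before new_feature existed, and one warning about the changed source is recorded. The flip side: pickle restores the saved attributes without calling __init__, so an attribute that only a modified __init__ would create is missing on the loaded object.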


This is awesome! Thanks for digging into the source code. :smiley: