Question about serialization while saving models in PyTorch

Let’s have a look at the underlying code:
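torch.save basically only calls torch.serialization._save. (In newer PyTorch releases this logic lives in torch.serialization._legacy_save and is only used when saving with _use_new_zipfile_serialization=False. The default zipfile format, introduced in PyTorch 1.6, does not store module source code.)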

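Inside this function, a nested function named persistent_id is defined. Among other things, its return values are pickled: whenever it returns something other than None, that value is written to the pickle stream in place of the original object.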
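
persistent_id is a standard hook of Python's pickle module, not something PyTorch invented. The following minimal sketch uses only the standard library. The Tag class and the TagPickler/TagUnpickler/roundtrip names are made up for illustration. It shows the general mechanism: whatever persistent_id returns is pickled in place of the object and handed back to persistent_load when loading.

```python
import io
import pickle


class Tag:
    """Hypothetical object we want to store by reference, not by value."""

    def __init__(self, name):
        self.name = name


class TagPickler(pickle.Pickler):
    def persistent_id(self, obj):
        # A non-None return value replaces obj in the stream, and that value
        # is itself pickled. Tuples (not just strings) need a binary protocol.
        if isinstance(obj, Tag):
            return ('tag', obj.name)
        return None  # everything else is pickled as usual


class TagUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        # pid is the unpickled value that persistent_id returned at save time.
        kind, name = pid
        if kind != 'tag':
            raise pickle.UnpicklingError(f"unknown persistent id: {pid!r}")
        return Tag(name)


def roundtrip(obj):
    buf = io.BytesIO()
    TagPickler(buf, protocol=2).dump(obj)
    buf.seek(0)
    return TagUnpickler(buf).load()


restored = roundtrip({'label': Tag('cat'), 'count': 3})
print(restored['label'].name, restored['count'])  # prints: cat 3
```

The Tag instance never reaches the stream itself; only the ('tag', 'cat') tuple does, and the loader decides how to turn it back into an object.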

For subclasses of torch.nn.Module (the model classes themselves, not their instances), this function does the following:

```python
import inspect
import warnings
import torch.nn as nn

serialized_container_types = {}  # lives in the enclosing _save scope


def persistent_id(obj):
    if isinstance(obj, type) and issubclass(obj, nn.Module):
        if obj in serialized_container_types:
            return None  # second visit: pickle the class normally, by reference
        serialized_container_types[obj] = True
        source_file = source = None
        try:
            source_file = inspect.getsourcefile(obj)
            source = inspect.getsource(obj)
        except Exception:  # saving the source is optional, so we can ignore any errors
            warnings.warn("Couldn't retrieve source code for container of "
                          "type " + obj.__name__ + ". It won't be checked "
                          "for correctness upon loading.")
        return ('module', obj, source_file, source)
    # (storage handling omitted)
    return None
```
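
To see this branch in action without installing PyTorch, here is a rough stdlib sketch under a few assumptions: Module is a hypothetical stand-in for torch.nn.Module, and Net, ModuleSourcePickler, ModuleSourceUnpickler, save and load are illustrative names, not PyTorch API. The loader here skips the source comparison and simply returns the class it receives, which pickle has already resolved by name.

```python
import inspect
import io
import pickle
import warnings


class Module:
    """Hypothetical stand-in for torch.nn.Module, so torch isn't needed."""


class Net(Module):
    def __init__(self):
        self.weight = 3.0

    def forward(self, x):
        return self.weight * x


class ModuleSourcePickler(pickle.Pickler):
    """Mimics the nn.Module branch of persistent_id."""

    def __init__(self, file):
        super().__init__(file, protocol=2)  # tuple ids need a binary protocol
        self.serialized_container_types = {}

    def persistent_id(self, obj):
        if isinstance(obj, type) and issubclass(obj, Module):
            # Second visit (the class inside our own tuple): returning None
            # makes pickle store it normally, by module and class name.
            if obj in self.serialized_container_types:
                return None
            self.serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_file = inspect.getsourcefile(obj)
                source = inspect.getsource(obj)
            except Exception:  # the source is optional, as in PyTorch
                warnings.warn(f"Couldn't retrieve source code for {obj.__name__}")
            return ('module', obj, source_file, source)
        return None


class ModuleSourceUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        typename, cls, source_file, source = pid
        if typename != 'module':
            raise pickle.UnpicklingError(f"unsupported persistent id {typename!r}")
        # cls was unpickled by name, so it is the class as currently defined.
        return cls


def save(obj):
    buf = io.BytesIO()
    ModuleSourcePickler(buf).dump(obj)
    return buf.getvalue()


def load(data):
    return ModuleSourceUnpickler(io.BytesIO(data)).load()


if __name__ == "__main__":
    data = save(Net())
    # Attach a method *after* saving; the loaded object still gets it,
    # because the class is looked up by name at load time.
    Net.double = lambda self, x: 2 * self.forward(x)
    print(load(data).double(2))  # prints 12.0
```

Assigning a lambda to the class is only a quick way to mimic editing the class in its source file within a single script.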

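This means the source file path and the source code are pickled too. During loading, the class's current source is read again and compared with the pickled source; if they differ, PyTorch emits a SourceChangeWarning (and, if torch.nn.Module.dump_patches is set to True, writes a patch file). The pickled source itself is never executed or merged: pickle stores the class by reference (module and class name), so the loaded object always uses the class as it is currently defined in your code. That is why adding new methods is fine. What has to stay compatible is the saved state: if existing methods now rely on attributes, submodules or parameters that the saved object doesn't have, loading will still succeed, but the model can fail once you use it.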
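
The comparison done at load time can be approximated with difflib. This is a hedged sketch: compare_saved_source is a hypothetical helper, and SourceChangeWarning here is a local stand-in for the PyTorch warning of the same name. Like the real check, it only reports the difference and never merges or runs either version of the code.

```python
import difflib
import warnings


class SourceChangeWarning(Warning):
    """Local stand-in for PyTorch's warning class of the same name."""


def compare_saved_source(type_name, saved_source, current_source):
    """Return a unified diff (saved -> current) and warn if the sources differ."""
    if saved_source is None or current_source is None:
        return ""  # storing the source was optional, so there may be nothing to check
    if saved_source == current_source:
        return ""
    diff = "\n".join(difflib.unified_diff(
        saved_source.splitlines(), current_source.splitlines(),
        fromfile="saved", tofile="current", lineterm=""))
    warnings.warn(f"source code of class '{type_name}' has changed",
                  SourceChangeWarning)
    return diff


saved = "class Net:\n    def forward(self, x):\n        return x\n"
current = saved + "\n    def extra(self):\n        return 1\n"
print(compare_saved_source("Net", saved, current))
```

Adding the extra method only shows up as "+" lines in the diff plus a warning; the object would still be built from the current class.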
