torch.load() error in Jupyter when saving in Python

Hi,

When I save a model in Python using torch.save(model, PATH) and then load it with torch.load() in Jupyter, I get the following error. (I know that torch.save(model.state_dict(), PATH) is better, but since my model architecture may change in the future, I prefer to save the whole model.)

TypeError                                 Traceback (most recent call last)
<ipython-input-16-88409a4cae12> in <module>()
     42 
---> 44 net = torch.load(filepath)

/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/torch/serialization.py in load(f, map_location, pickle_module)
    265         f = open(f, 'rb')
    266     try:
--> 267         return _load(f, map_location, pickle_module)
    268     finally:
    269         if new_fd:

/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/torch/serialization.py in _load(f, map_location, pickle_module)
    418     unpickler = pickle_module.Unpickler(f)
    419     unpickler.persistent_load = persistent_load
--> 420     result = unpickler.load()
    421 
    422     deserialized_storage_keys = pickle_module.load(f)

/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/torch/serialization.py in persistent_load(saved_id)
    381             # Ignore containers that don't have any sources saved
    382             if all(data[1:]):
--> 383                 _check_container_source(*data)
    384             return data[0]
    385         elif typename == 'storage':

/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/torch/serialization.py in _check_container_source(container_type, source_file, original_source)
    291 
    292     def _check_container_source(container_type, source_file, original_source):
--> 293         current_source = inspect.getsource(container_type)
    294         if original_source != current_source:
    295             if container_type.dump_patches:

/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/inspect.py in getsource(object)
    942     or code object.  The source code is returned as a single string.  An
    943     OSError is raised if the source code cannot be retrieved."""
--> 944     lines, lnum = getsourcelines(object)
    945     return ''.join(lines)
    946 

/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/inspect.py in getsourcelines(object)
    929     raised if the source code cannot be retrieved."""
    930     object = unwrap(object)
--> 931     lines, lnum = findsource(object)
    932 
    933     if ismodule(object):

/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/inspect.py in findsource(object)
    742     is raised if the source code cannot be retrieved."""
    743 
--> 744     file = getsourcefile(object)
    745     if file:
    746         # Invalidate cache if needed.

/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/inspect.py in getsourcefile(object)
    658     Return None if no way can be identified to get the source.
    659     """
--> 660     filename = getfile(object)
    661     all_bytecode_suffixes = importlib.machinery.DEBUG_BYTECODE_SUFFIXES[:]
    662     all_bytecode_suffixes += importlib.machinery.OPTIMIZED_BYTECODE_SUFFIXES[:]

/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/inspect.py in getfile(object)
    611             if hasattr(object, '__file__'):
    612                 return object.__file__
--> 613         raise TypeError('{!r} is a built-in class'.format(object))
    614     if ismethod(object):
    615         object = object.__func__

TypeError: <module '__main__'> is a built-in class

Is there a remedy for this problem?
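The last frames of the traceback show where this fails: torch.load's _check_container_source calls inspect.getsource on the saved model class, and inspect.getfile raises because the class's module (the Jupyter kernel's __main__) has no __file__ attribute. The stdlib-only sketch below reproduces just that inspect behaviour, without PyTorch; the module name fake_notebook_main and the empty Net class are hypothetical stand-ins. Newer Python versions may raise OSError instead of TypeError for classes defined in __main__, so the helper catches both.

```python
import inspect
import json.decoder
import sys
import types


def source_available(cls):
    """Return True if inspect.getsource can locate the source of cls.

    This is the same inspect.getsource call that torch.load's
    _check_container_source makes on the saved model class.
    """
    try:
        inspect.getsource(cls)
        return True
    except (TypeError, OSError):
        # TypeError: "... is a built-in class" (as in the traceback);
        # newer Python versions may raise OSError for __main__ instead.
        return False


# Hypothetical stand-in for the notebook's __main__: a module with no __file__.
notebook_main = types.ModuleType("fake_notebook_main")
sys.modules["fake_notebook_main"] = notebook_main


class Net:
    """Hypothetical stand-in for the user's model class."""


# Pretend Net was defined in the file-less module, like a notebook class.
Net.__module__ = "fake_notebook_main"

print(source_available(Net))                       # False
print(source_available(json.decoder.JSONDecoder))  # True: defined in a .py file
```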

Does this happen for every model you save in Jupyter?

Yes, this happens for every model I save in Python 3.5 and then load in Jupyter.

What happens if you save a model in Jupyter, and then try to load it in Jupyter?

That works fine. The problem is that I want to run batch jobs with Python and then load the saved models in Jupyter to analyze them.

torch.save/torch.load are only guaranteed to work if the same pickle module is used for both. Are you on the same version of Python for both Jupyter (loading) and Python (saving)?

Yes, they are using the same version and core.

Okay, I will try to reproduce this and get back to you.


Can’t seem to reproduce:
(I imported main because that’s where I defined Net. As a sanity check, did you import the module containing your network into Jupyter?)
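That sanity check matters because pickle, which torch.save uses for whole models, stores a class by reference: the file records the module name and class name, and the loading process must be able to import the class under that same name. The stdlib-only sketch below illustrates this without PyTorch; hypothetical_models and the toy Net class are made-up stand-ins, and pickling vars(Net()) is only an analogy for saving state_dict(), not the real PyTorch API.

```python
import pickle
import sys
import types

# Hypothetical module standing in for a file such as models.py that both the
# batch script and the notebook would import. Registering it in sys.modules
# lets pickle find it the same way it would after a normal import.
models = types.ModuleType("hypothetical_models")
sys.modules["hypothetical_models"] = models


class Net:
    """Toy stand-in for the model class; not a real torch.nn.Module."""

    def __init__(self):
        self.weights = [0.1, 0.2]


Net.__module__ = "hypothetical_models"
models.Net = Net

# Saving the whole object records *where* to find the class, not its code.
whole_model = pickle.dumps(Net())
print(b"hypothetical_models" in whole_model)  # True

# Loading works while the class is importable under that module name.
restored = pickle.loads(whole_model)
print(type(restored) is Net, restored.weights)  # True [0.1, 0.2]

# Saving only the parameters (the idea behind state_dict) needs no class.
params_only = pickle.dumps(vars(Net()))

# Simulate a loading process where Net cannot be found under the recorded name.
del models.Net
load_failed = False
try:
    pickle.loads(whole_model)
except (AttributeError, pickle.UnpicklingError) as exc:
    load_failed = True
    print("whole-model load failed:", exc)

print(pickle.loads(params_only))  # {'weights': [0.1, 0.2]}
```

Applied to this thread, and assuming Net can be moved, defining it in a separate .py file that both the batch script and the notebook import means the pickle refers to that module rather than __main__, and inspect can find its source file. A model saved while Net lived in the script's __main__ still refers to __main__.Net, so it would need to be re-saved after the move.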