Pickle nn.Module

I have a class that holds many nn.Modules as well as several pieces of data that aren’t modules, e.g.

class MyClass(object):
    def __init__(self):
        self._module1 = ...
        self._module2 = ...
        self._not_a_module1 = ...
        self._not_a_module2 = ...

I’d like to save and restore all of the data held by the class. Is it safe for me to directly pickle the class? e.g.

import pickle

c = MyClass()
with open("saved.txt", "wb") as f:  # pickle writes bytes, so open in binary mode
    pickle.dump(c, f)

# ... later
with open("saved.txt", "rb") as f:  # read back in binary mode as well
    c = pickle.load(f)

Or is it necessary for me to call torch.save on the modules and separately save the non-module data on my own?

Related question: is directly pickling nn.Modules safe to do? Or do I have to save them using torch.save?

torch.save() is a superset of pickle.dump(): it uses pickle under the hood, so it has the same capabilities, and it adds more efficient handling of PyTorch tensors.
Calling torch.save() on both the modules and the non-module data should be the simplest solution for you.
I am not sure what the exact impact of using pickle.dump() on Module objects would be. If you do need to know, you could ping smth to find out who to ask about this. You can also check the implementation in this Python file.
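
To make the torch.save() suggestion concrete, here is a minimal sketch of one way it could look. Everything specific in it is my own assumption for illustration: the nn.Linear layers, the placeholder data, the save/load method names and the dictionary keys are all hypothetical. It also stores each module's state_dict() instead of pickling the MyClass object itself, which is a choice I made for the sketch, not something required by the answer above.

```python
import os
import tempfile

import torch
import torch.nn as nn


class MyClass(object):
    """Hypothetical stand-in for the class in the question."""

    def __init__(self):
        # Placeholder modules and data; the real ones are not shown in the question.
        self._module1 = nn.Linear(4, 2)
        self._module2 = nn.Linear(2, 1)
        self._not_a_module1 = {"step": 10}
        self._not_a_module2 = [1.0, 2.0]

    def save(self, path):
        # A single torch.save call covers module weights and plain Python data.
        torch.save(
            {
                "module1": self._module1.state_dict(),
                "module2": self._module2.state_dict(),
                "not_a_module1": self._not_a_module1,
                "not_a_module2": self._not_a_module2,
            },
            path,
        )

    def load(self, path):
        # The checkpoint holds only tensors and basic Python containers.
        checkpoint = torch.load(path)
        self._module1.load_state_dict(checkpoint["module1"])
        self._module2.load_state_dict(checkpoint["module2"])
        self._not_a_module1 = checkpoint["not_a_module1"]
        self._not_a_module2 = checkpoint["not_a_module2"]


c = MyClass()
c._not_a_module1["step"] = 42  # change something so the round trip is visible

path = os.path.join(tempfile.mkdtemp(), "saved.pt")
c.save(path)

# ... later: build a fresh object (with different random weights) and restore it
restored = MyClass()
restored.load(path)

same_weights = torch.equal(restored._module1.weight, c._module1.weight)
same_data = restored._not_a_module1 == c._not_a_module1
print(same_weights, same_data)
```

One reason to prefer this layout: the saved file does not depend on MyClass being importable under the same name at load time, whereas a pickle of the whole object does.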