Hi, I am new to PyTorch and was wondering whether there is an API for deep copying Modules. I have a function that accepts a Module as input and trains it. Because I don’t want to keep track of the object’s state, I am planning to deep copy the trained model and save the copy somewhere else, in a database.
Is it safe to just do
If I had to deep copy a model, I would do something like this:

```python
model_copy = type(mymodel)() # get a new instance
model_copy.load_state_dict(mymodel.state_dict()) # copy weights and stuff
```
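A minimal, self-contained version of that snippet, assuming a hypothetical `MyModel` whose constructor takes no arguments, shows what the copy actually is: `load_state_dict` copies the values into the new instance's own parameters, so later changes to `mymodel` do not leak into `model_copy`.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):  # hypothetical model with a no-argument constructor
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

mymodel = MyModel()

model_copy = type(mymodel)()  # new instance with freshly initialised weights
model_copy.load_state_dict(mymodel.state_dict())  # copy parameters and buffers

same_before = torch.equal(model_copy.fc.weight, mymodel.fc.weight)

# Simulate a training update on the original; in-place edits of a leaf
# parameter must happen under no_grad.
with torch.no_grad():
    mymodel.fc.weight.add_(1.0)

same_after = torch.equal(model_copy.fc.weight, mymodel.fc.weight)
print(same_before, same_after)  # True False
```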
I can’t see how that would help you “save it somewhere else in a database”. I assume you have read the docs on serialising models and the recommendations and warnings on that page.
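For the database use case, one option that follows the docs’ recommendation to save the `state_dict` rather than the whole model is to serialise it into an in-memory buffer and store the resulting bytes (for example in a BLOB column). This is a sketch: `state_dict_to_bytes` and `load_from_bytes` are hypothetical helper names, and `weights_only=True` assumes PyTorch 1.13 or later.

```python
import io
import torch
import torch.nn as nn

def state_dict_to_bytes(model: nn.Module) -> bytes:
    """Hypothetical helper: serialise a model's state_dict to raw bytes."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getvalue()

def load_from_bytes(model: nn.Module, blob: bytes) -> nn.Module:
    """Hypothetical helper: load bytes from state_dict_to_bytes into a fresh instance."""
    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(io.BytesIO(blob), map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    return model

model = nn.Linear(4, 2)
blob = state_dict_to_bytes(model)              # store these bytes in the database
restored = load_from_bytes(nn.Linear(4, 2), blob)
```

The architecture still has to be rebuilt in code before loading; only the tensors live in the database.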
In `type(mymodel)()`, are the constructor arguments copied from the `mymodel` instance by default? Or do we need to make sure the arguments match those of the `mymodel` instance?
If the `__init__` function of `mymodel` takes arguments, you must pass them too.
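A small sketch of what happens when the constructor takes arguments, using a hypothetical `MLP` class: calling `type(mymodel)()` with no arguments fails, and passing different arguments builds differently shaped parameters that `load_state_dict` refuses.

```python
import torch.nn as nn

class MLP(nn.Module):  # hypothetical model whose __init__ takes arguments
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)

mymodel = MLP(8, 16, 2)

# The arguments are not taken from mymodel, so calling with none fails.
no_args_error = None
try:
    type(mymodel)()
except TypeError as err:
    no_args_error = err

# Matching arguments give parameters of the right shapes to copy into.
model_copy = type(mymodel)(8, 16, 2)
model_copy.load_state_dict(mymodel.state_dict())

# Mismatched arguments give different shapes, which load_state_dict
# rejects with a RuntimeError (size mismatch).
mismatch_error = None
try:
    type(mymodel)(8, 32, 2).load_state_dict(mymodel.state_dict())
except RuntimeError as err:
    mismatch_error = err
```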
What if I do
All I want is to use an instance of a model as an input to another function during training. I do not need any gradient flow through this operation.
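For that case, a minimal sketch (not necessarily what the thread settled on) is to take a `copy.deepcopy` of the model, switch it to eval mode, turn off `requires_grad` on its parameters, and call it under `torch.no_grad()`. `frozen_copy` is a hypothetical helper name, and the squared-error loss is only an illustrative use of the copy’s output.

```python
import copy
import torch
import torch.nn as nn

def frozen_copy(model: nn.Module) -> nn.Module:
    """Hypothetical helper: independent copy that takes no part in autograd."""
    snapshot = copy.deepcopy(model)  # new parameters in separate storage
    snapshot.eval()                  # fix dropout / batch-norm behaviour
    snapshot.requires_grad_(False)   # no gradients for the copy's parameters
    return snapshot

model = nn.Linear(4, 2)
target = frozen_copy(model)

x = torch.randn(3, 4)
with torch.no_grad():
    y_target = target(x)  # no graph is recorded for the copy

# Only the original model's parameters receive gradients.
loss = ((model(x) - y_target) ** 2).mean()
loss.backward()
```

Refresh the snapshot by calling `frozen_copy(model)` again, or with `target.load_state_dict(model.state_dict())`, whenever it should follow the trained model.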