Deep copying PyTorch modules

Hi, I am new to PyTorch and was wondering whether there is an API for deep copying Modules. I have a function that accepts a Module as input and trains it. Because I don’t want to keep track of the object’s state, I am planning to just deep copy the trained model and save the copy in a database.

Is it safe to just do copy.deepcopy(mymodel)?
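For what it's worth, here is a quick check of what copy.deepcopy does on an ordinary module. It is a minimal sketch that uses nn.Linear as a stand-in for mymodel. It shows that the copy gets its own parameter tensors. It does not prove deepcopy is safe for every module, though. For example, a module that stores non-leaf tensors or other non-copyable objects as attributes may fail to copy.

```python
import copy

import torch
import torch.nn as nn

# Stand-in for mymodel, used only for illustration.
model = nn.Linear(4, 2)

# deepcopy recursively copies the module, including its parameters and buffers.
model_copy = copy.deepcopy(model)

# The copy starts with identical weights...
assert torch.equal(model.weight, model_copy.weight)

# ...but the tensors are distinct objects with separate storage.
assert model.weight.data_ptr() != model_copy.weight.data_ptr()

# Changing the original in place leaves the copy untouched.
with torch.no_grad():
    model.weight.add_(1.0)
assert not torch.equal(model.weight, model_copy.weight)
```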

If I had to deep copy a model, I would do something like this…

model_copy = type(mymodel)() # get a new instance (assumes __init__ takes no arguments)
model_copy.load_state_dict(mymodel.state_dict()) # copy parameters and buffers
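A self-contained version of the two lines above might look like this. MyModel is a hypothetical class whose __init__ takes no arguments, so type(mymodel)() works. The check at the end confirms the copy has the same values but separate tensors. Keep in mind that state_dict only covers registered parameters and buffers. Plain Python attributes and the train/eval flag are not copied this way.

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    # Hypothetical model whose __init__ takes no arguments, so type(m)() works.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        return self.fc(x)


mymodel = MyModel()

model_copy = type(mymodel)()  # new instance with freshly initialised weights
model_copy.load_state_dict(mymodel.state_dict())  # copy parameters and buffers

# Same values, but each parameter lives in its own storage.
for (name, p), (_, q) in zip(mymodel.named_parameters(), model_copy.named_parameters()):
    assert torch.equal(p, q), name
    assert p.data_ptr() != q.data_ptr(), name
```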

I can’t see how that would help you “save it somewhere else in a database”. I assume you have read the docs on serialising models and the recommendations and warnings on that page.
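If the end goal is storing the trained model in a database, one option is to serialise the state_dict to bytes rather than keeping a live copy around. Saving the state_dict is the approach the serialization notes generally recommend. This is a sketch under assumptions: nn.Linear stands in for the real model, and the actual database write and read are not shown. The bytes would go into a BLOB-style column.

```python
import io

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the trained model

# Serialise only the state_dict into an in-memory bytes buffer.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
blob = buffer.getvalue()  # raw bytes to store in the database (not shown)

# Later: rebuild the same architecture and load the stored weights.
restored = nn.Linear(4, 2)
state = torch.load(io.BytesIO(blob), map_location="cpu")
restored.load_state_dict(state)

assert torch.equal(model.weight, restored.weight)
```

Loading this way requires the code that builds the architecture to be available. That is the trade-off compared with pickling the whole module.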


In type(mymodel)(), are the constructor arguments copied from the mymodel instance by default, or do we need to make sure they match the arguments used to create mymodel?


If the __init__ function of mymodel takes arguments, you must pass them too: type(mymodel)(*args)
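One way to make the arguments available is to have the model remember them. In this sketch, MLP and its init_args attribute are hypothetical names chosen for illustration. PyTorch does not store constructor arguments for you. If the arguments did not match, load_state_dict would typically fail with a size-mismatch error instead of silently copying.

```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    # Hypothetical model whose constructor requires arguments.
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        # Keep the constructor arguments so an identical instance can be rebuilt.
        self.init_args = (in_features, hidden, out_features)
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)


mymodel = MLP(8, 16, 2)

# type(mymodel)() would raise TypeError here, so the arguments are supplied.
model_copy = type(mymodel)(*mymodel.init_args)
model_copy.load_state_dict(mymodel.state_dict())

# Same weights and same input give the same output.
x = torch.randn(5, 8)
assert torch.equal(mymodel(x), model_copy(x))
```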


What if I do

model_copy.load_state_dict(mymodel.state_dict()).to(device='cpu')?

All I want is to use an instance of a model as an input to another function during training. I do not need any gradient flow through this operation.
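A caveat on the chained call in the question: in current PyTorch versions, load_state_dict returns a named tuple of missing and unexpected keys, not the module. Calling .to() on that result would raise AttributeError. Module.to(), on the other hand, returns the module itself, so it can be called as a separate step. The sketch below uses nn.Linear as a stand-in. It freezes the copy so it can be passed to another function without recording gradients. Whether eval() is appropriate depends on the model (it matters for dropout and batch norm), so treat that line as an assumption.

```python
import torch
import torch.nn as nn

mymodel = nn.Linear(4, 2)  # stand-in for the model being trained

model_copy = nn.Linear(4, 2)
result = model_copy.load_state_dict(mymodel.state_dict())
# result holds missing_keys / unexpected_keys; it is not a Module,
# so .to() cannot be chained onto it.
assert not isinstance(result, nn.Module)

# Module.to() returns the module, so call it as its own statement.
model_copy = model_copy.to(device="cpu")

# Freeze the copy: no parameter of it will require gradients.
model_copy.requires_grad_(False)
model_copy.eval()  # assumption: inference behaviour is wanted for this copy

# Run it without building an autograd graph.
with torch.no_grad():
    out = model_copy(torch.randn(3, 4))
assert not out.requires_grad
```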