I have questions about training networks progressively in PyTorch

I don’t really understand the proper way, if there is one, to progressively add layers to a neural network for something like a progressive autoencoder or GAN. First of all, should you create the entire network up front, block access to the outer layers, and only take results from the inner layers at first? Or do you start with a small network and, once it hits some threshold, add a new layer to it? Finally, how do you add new layers to the network during training? I would like to experiment with a progressive autoencoder; for the input it’s easy since we can just use transforms.resize(), but how do you add outer layers that take the new, larger inputs on the fly?

Edited: Also, if it’s possible to add new layers during training, do I need to call model = Model() again? Will that reset all the previous weight parameters?

The simplest way to do it is to train certain layers while having the other layers act as the identity function. You can select which parameters to train when creating your optimizer (https://pytorch.org/docs/stable/optim.html#per-parameter-options). To speed things up, you can avoid computing gradients for the modules you are not training (https://pytorch.org/docs/stable/notes/autograd.html#excluding-subgraphs-from-backward).
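
Here is a minimal sketch of those two ideas (the layer sizes and module layout are made up for illustration, not taken from the posts above): freeze the layers you are not training yet so no gradients are computed for them, and hand the optimizer only the parameter groups you do want to update.

```python
import torch
import torch.nn as nn

# Hypothetical progressive autoencoder: the outer layers exist but are not trained yet.
model = nn.Sequential(
    nn.Linear(64, 32),  # new outer encoder layer, frozen for now
    nn.Linear(32, 16),  # inner encoder layer, trained
    nn.Linear(16, 32),  # inner decoder layer, trained
    nn.Linear(32, 64),  # new outer decoder layer, frozen for now
)

# Exclude the frozen modules from backward so autograd skips their gradients.
for module in (model[0], model[3]):
    for p in module.parameters():
        p.requires_grad = False

# Per-parameter options: pass only the trainable parameters,
# optionally as separate groups with their own learning rates.
optimizer = torch.optim.SGD(
    [
        {"params": model[1].parameters()},
        {"params": model[2].parameters(), "lr": 1e-4},
    ],
    lr=1e-3,
)
```

When you later "unfreeze" the outer layers, you can flip requires_grad back on and add their parameters to the optimizer as a new parameter group.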

Dynamically adding/removing modules is relatively easy in TensorFlow/Keras, since a graph of the model is available on the Python side. In PyTorch, you cannot traverse the graph of your model to insert modules.

What you could do is re-instantiate a model with more layers, but then you will have trouble loading the state_dict. (It’s still doable: first get the state_dict of the new, bigger model, copy over the available keys/values from the state_dict of the smaller model, and then load the mutated state_dict into the new model. That’s what I do myself when replacing modules inside a network.)
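
A minimal sketch of that state_dict transplant, assuming named submodules so the keys stay stable as the model grows (the sizes and names here are illustrative, not from the posts above):

```python
import torch
import torch.nn as nn
from collections import OrderedDict

def make_model(grown=False):
    # Named submodules keep state_dict keys like "enc1.weight" identical
    # between the small and the grown model.
    core = [("enc1", nn.Linear(32, 16)), ("dec1", nn.Linear(16, 32))]
    if grown:
        core = [("enc0", nn.Linear(64, 32))] + core + [("dec0", nn.Linear(32, 64))]
    return nn.Sequential(OrderedDict(core))

small_model = make_model(grown=False)
# ... train small_model here ...

big_model = make_model(grown=True)

# Copy every trained tensor whose key and shape also exist in the bigger model;
# the newly added layers keep their fresh initialization.
big_state = big_model.state_dict()
for key, value in small_model.state_dict().items():
    if key in big_state and big_state[key].shape == value.shape:
        big_state[key] = value
big_model.load_state_dict(big_state)
```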

Good luck!

Thank you! I am currently looking at a progressive GAN implementation in PyTorch, and just as you said, they re-instantiate their network and copy the state_dict back every time they want to grow it. It looks easy enough to implement, so I might just do that first.
