Hi, I was wondering how to store some of the model's modules (including their weights and biases) in a global variable so that I can copy them back later?
You can save the network parameters in a list, if that's what you are after; whether it is global or local is your design choice. Loop over the network parameters and save them. Check out nn.ModuleList as well, since it may be of use.
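A minimal sketch of the "loop over the parameters and save them in a list" idea. The two-layer model is a hypothetical stand-in. The parameters are detached and cloned so that later training does not change the saved values, and they are restored by zipping over `model.parameters()`, which yields them in the same order each time:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Save: detached copies, so later in-place updates do not alter them.
saved_params = [p.detach().clone() for p in model.parameters()]

# Simulate training changing the weights in place.
with torch.no_grad():
    for p in model.parameters():
        p.mul_(0.5)

# Restore: parameters() yields the tensors in the same order every time.
with torch.no_grad():
    for p, saved in zip(model.parameters(), saved_params):
        p.copy_(saved)
```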
You can also get all the parameters as a nice dictionary with model.state_dict(). Then use model.load_state_dict(state_dict) to restore it.
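A short sketch of the `state_dict()` round trip, using a hypothetical stand-in model. One detail worth knowing: the tensors in `state_dict()` share storage with the live parameters, so the sketch deep-copies the dict to get a real snapshot before the weights change:

```python
import copy

import torch
import torch.nn as nn

# Hypothetical stand-in model.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# state_dict() values share storage with the parameters, so deep-copy it
# to get a snapshot that later in-place updates (e.g. optimizer steps) won't touch.
saved_state = copy.deepcopy(model.state_dict())

# Simulate training modifying the weights in place.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# Copy the snapshot back into the model.
model.load_state_dict(saved_state)
```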
I want something like this:
store all the Linear layers (modules or weights) in a list,
then loop over all the model's modules and, if a module is Linear, assign it the stored Linear module that matches its index.
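The index-based approach described above could be sketched like this, assuming a hypothetical stand-in model whose Linear layers all have a bias (the `nn.Linear` default). `model.modules()` visits modules in a fixed order, so position `idx` in the saved list lines up with the `idx`-th Linear layer on the second pass:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; assumes every Linear layer has a bias.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Pass 1: store a copy of every Linear layer's weight and bias, in module order.
saved = []
for mod in model.modules():
    if isinstance(mod, nn.Linear):
        saved.append((mod.weight.detach().clone(), mod.bias.detach().clone()))

# Simulate the weights being changed in the meantime.
with torch.no_grad():
    for p in model.parameters():
        p.zero_()

# Pass 2: walk the Linear layers in the same order and copy back by index.
linear_layers = [m for m in model.modules() if isinstance(m, nn.Linear)]
with torch.no_grad():
    for idx, mod in enumerate(linear_layers):
        weight, bias = saved[idx]
        mod.weight.copy_(weight)
        mod.bias.copy_(bias)
```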
Ok,
in that case, I would use for mod_uniq_name, mod in model.named_modules() to find all the Linear layers and save their weights in some structure (like a dict), using mod_uniq_name as the key.
Then, when you want to reload them, do the same iteration and, if mod_uniq_name is present in your saved structure, load what is stored there.
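A minimal sketch of the `named_modules()` approach, again with a hypothetical stand-in model whose Linear layers have biases. For an `nn.Sequential`, the names are the child indices as strings (`"0"`, `"2"` here), and the root module appears with the empty name `""`, which the `isinstance` check skips:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model; assumes every Linear layer has a bias.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Save: key each Linear layer's tensors by its unique module name.
saved = {}
for mod_uniq_name, mod in model.named_modules():
    if isinstance(mod, nn.Linear):
        saved[mod_uniq_name] = {
            "weight": mod.weight.detach().clone(),
            "bias": mod.bias.detach().clone(),
        }

# Simulate the weights being changed in the meantime.
with torch.no_grad():
    for p in model.parameters():
        p.zero_()

# Reload: same iteration; only restore modules present in the saved dict.
with torch.no_grad():
    for mod_uniq_name, mod in model.named_modules():
        if mod_uniq_name in saved:
            mod.weight.copy_(saved[mod_uniq_name]["weight"])
            mod.bias.copy_(saved[mod_uniq_name]["bias"])
```

Keying by name rather than by position means the restore still works if you only saved a subset of the layers.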
Thank you, but my case is a bit different: I am using a kind of tree, so I have to save every Linear node in the forward function, which means I need a different way to store them, such as:
a list that stores the weights, or a list that stores the entire modules. Which one do you think is better?
If you don't plan on modifying the module itself, I think saving only the weights is better.
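To make the two options concrete, here is a small sketch on a single hypothetical `nn.Linear`. Note that appending the module object itself to a list stores only a reference, so a real copy needs either cloned tensors or `copy.deepcopy`:

```python
import copy

import torch
import torch.nn as nn

layer = nn.Linear(4, 3)

# Option A: store only the tensors (detached copies, no autograd history).
weights_only = (layer.weight.detach().clone(), layer.bias.detach().clone())

# Option B: store an independent copy of the whole module.
module_copy = copy.deepcopy(layer)

# Appending `layer` itself would only store a reference: later in-place
# updates to `layer` would show up in the "saved" object too.
alias = layer
```

Option A keeps just the numbers, which is enough when the module structure stays the same; option B also duplicates the module object, which is only needed if the module itself might change.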
Yes, but iterating over that list breaks down because I can't index it. If I try:
```python
for p55 in temppara:
    print(p55)
    print(temppara.index(p55))
    input()
```
I get errors from the index() call, and I need an index to know which stored module maps to which module, or which stored weights map to which module's weights.
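A likely cause of the error, sketched with a hypothetical list of same-shaped tensors standing in for `temppara`: `list.index()` searches by comparing elements with `==`, and for tensors `==` is elementwise, so Python cannot turn the result into a single True/False and PyTorch raises a `RuntimeError`. `enumerate` sidesteps the search entirely by handing you the position alongside each item:

```python
import torch

# Hypothetical stand-in for the saved list of weight tensors.
temppara = [torch.randn(3, 4) for _ in range(3)]

# index() must evaluate temppara[0] == temppara[1] as a bool, which fails
# for multi-element tensors.
index_failed = False
try:
    temppara.index(temppara[1])
except RuntimeError as err:
    index_failed = True
    print("index() failed:", err)

# enumerate yields (position, item) pairs, so no equality search is needed.
positions = []
for idx, p55 in enumerate(temppara):
    positions.append(idx)
    print(idx, tuple(p55.shape))
```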
If you have the same tree structure, you should encounter the same modules in the same order during the forward pass, no? So you can simply .append() and .pop() your list.