Reset the parameters of a model

Is there any method to reset the parameters of a model, or do I have to save the state_dict of a newly created model and load it whenever I want to retrain?

It depends on your use case.
If you need exactly the same parameters for the new model in order to recreate some experiment, I would save and reload the state_dict as this would probably be the easiest method.
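A minimal sketch of the save-and-reload approach. The small `nn.Sequential` model and the in-place update that stands in for training are hypothetical. `copy.deepcopy` matters here because `state_dict()` returns tensors that share storage with the live parameters. Without a copy, the "snapshot" would change as the model trains.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Snapshot the freshly initialized parameters. deepcopy detaches the snapshot
# from the live tensors, which training would otherwise mutate in place.
initial_state = copy.deepcopy(model.state_dict())
# To keep it across runs, it could also be written to disk:
# torch.save(initial_state, "init.pt")

# ... training would happen here; simulated by an in-place parameter update ...
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# Restore the exact initial parameters before retraining.
model.load_state_dict(initial_state)
```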

However, if you just want to train from scratch, you could simply instantiate a new model, which initializes all parameters with each layer's default scheme, or write a function that initializes your parameters and apply it to the model:

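import torch.nn as nn

def weight_init(m):
    # Re-initialize only (transposed) convolution layers; other modules are left untouched.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
        # Layers created with bias=False have m.bias set to None.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# apply() calls weight_init recursively on every submodule and on model itself.
model.apply(weight_init)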
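For the first option (instantiating a new model), a common pattern is to wrap model construction in a factory and call it again before each run. Seeding the global RNG first makes the new default initialization reproducible, at least on the same PyTorch version and device. This is a minimal sketch; `make_model`, `fresh_model`, and the layer sizes are hypothetical.

```python
import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical architecture; substitute your own model class.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 1, kernel_size=3, padding=1),
    )


def fresh_model(seed: int) -> nn.Module:
    # Each layer runs its default initialization in __init__, drawing from the
    # global RNG, so seeding first makes the result deterministic.
    torch.manual_seed(seed)
    return make_model()


model_a = fresh_model(seed=0)
model_b = fresh_model(seed=0)  # same seed: identical initial parameters
model_c = fresh_model(seed=1)  # different seed: different initial parameters
```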

Does this work for you?


This probably works too:


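This is not a robust solution. It only resets layers that define reset_parameters() (essentially the built-in torch.nn layers), and model.children() yields only the direct submodules, so layers nested inside a container such as nn.Sequential are skipped. For a flat model, though, this works: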

for layer in model.children():
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()
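A variant that also reaches nested layers iterates over `model.modules()` instead, which yields the model itself and every submodule recursively. A minimal sketch follows. The helper name `reset_all_parameters` and the nested model are hypothetical. Custom modules that do not define `reset_parameters()` are still skipped.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model with a nested container: the first Linear lives inside
# an inner nn.Sequential, which has no reset_parameters() of its own.
model = nn.Sequential(
    nn.Sequential(nn.Linear(4, 8), nn.ReLU()),
    nn.Linear(8, 2),
)


def reset_all_parameters(module: nn.Module) -> None:
    # modules() walks the whole module tree, not just the direct children.
    for m in module.modules():
        reset = getattr(m, "reset_parameters", None)
        if callable(reset):
            reset()


before = [p.detach().clone() for p in model.parameters()]
reset_all_parameters(model)
after = [p.detach().clone() for p in model.parameters()]
```

Replacing `module.modules()` with `module.children()` here would leave the inner `nn.Linear(4, 8)` unchanged, because the nested `nn.Sequential` has no `reset_parameters()` to call.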

credit: Reinitializing the weights after each cross validation fold