Is there any method to reset the parameters of the model?
Or do I have to save the state_dict of a new model and load it when I want to retrain the model?
It depends on your use case.
If you need exactly the same parameters for the new model in order to recreate some experiment, I would save and reload the state_dict, as this would probably be the easiest method.
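As a rough sketch of the save-and-reload approach: the small model and the fake "training" step below are hypothetical stand-ins, and the snapshot is kept in memory rather than written to disk (torch.save / torch.load would work similarly).

```python
import copy

import torch
import torch.nn as nn

# Hypothetical small model, used only for illustration.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Snapshot the freshly initialized parameters. The deepcopy matters:
# state_dict() returns references to the live tensors, so without a copy
# the snapshot would change as the model trains.
initial_state = copy.deepcopy(model.state_dict())

# Stand-in for training: perturb every parameter in place.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# Restore the original parameters before the next run.
model.load_state_dict(initial_state)
```

If you train with an optimizer that keeps internal state (e.g. momentum), you would probably want to re-create it as well before retraining.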
However, if you just want to train from scratch, you could simply instantiate a new model, which will reset all parameters by default, or use a method to initialize your parameters:
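import torch.nn as nn

def weight_init(m):
    # Re-initialize only (transposed) convolution layers
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
        if m.bias is not None:  # layers created with bias=False have no bias
            nn.init.zeros_(m.bias)

model.apply(weight_init)  # apply() visits every submodule recursively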
Does this work for you?
This probably works too:
This is not a robust solution and won't work for anything except core torch.nn layers, but it works:
for layer in model.children():
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()
credit: Reinitializing the weights after each cross validation fold
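One caveat with the loop above: model.children() only yields direct children, so layers nested inside containers are skipped. A possible variant, sketched here with a hypothetical helper name (reset_all_parameters) and a made-up example model, uses model.apply() to reach every submodule:

```python
import torch
import torch.nn as nn


def reset_all_parameters(model: nn.Module) -> None:
    """Call reset_parameters() on every submodule that defines it."""

    def _reset(module: nn.Module) -> None:
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()

    # apply() visits every submodule recursively, then the model itself.
    model.apply(_reset)


# Hypothetical model with a nested container, for illustration only.
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)),
    nn.Flatten(),
    nn.Linear(8, 2),
)

reset_all_parameters(model)
```

The same limitation as before still applies: custom modules that create parameters directly and do not define reset_parameters() are left untouched.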