Suppose my model is a more complex architecture (for example a U-Net). What is the best way to reset its weights? I am trying to use k-fold cross-validation with a U-Net, but I don’t know how to reset its weights.
If you haven’t written a custom weights_init method and just initialize the model with the default random initializations, I would recommend simply recreating the model.
Note that you should also recreate the optimizer in this case.
On the other hand, if you already defined a custom weights_init method, just reset the model via model.apply(weights_init).
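As a rough sketch of the model.apply(weights_init) approach: the init scheme below (Kaiming normal for weights, zeros for biases) and the tiny nn.Sequential stand-in for the U-Net are assumptions for illustration, not something taken from your code.

```python
import torch
import torch.nn as nn

def weights_init(m):
    # Hypothetical init scheme: Kaiming-normal weights and zero biases
    # for conv and linear layers; other module types are left untouched.
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Tiny stand-in for the real U-Net.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 1, kernel_size=3, padding=1),
)

model.apply(weights_init)            # initial setup
before = model[0].weight.detach().clone()

# ... train on one fold ...

model.apply(weights_init)            # reset before the next fold
after = model[0].weight.detach().clone()
```

model.apply calls the function on every submodule recursively, so nested blocks in a U-Net would be covered as well.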
Also, not sure if this fits your use case, but you could initialize the model once, create a copy.deepcopy of its state_dict, and reload this state_dict for each fold via model.load_state_dict(state_dict).
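A minimal sketch of the state_dict approach, assuming a small stand-in model, a dummy training step, and a hypothetical num_folds value. Note that with this approach every fold starts from the same initial weights rather than from a fresh random draw.

```python
import copy
import torch
import torch.nn as nn

# Tiny stand-in for the real U-Net.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 1, kernel_size=3, padding=1),
)

# Snapshot the freshly initialized parameters once.
# deepcopy is needed because state_dict() returns references to the live tensors.
init_state = copy.deepcopy(model.state_dict())

num_folds = 3  # hypothetical
for fold in range(num_folds):
    # Restore the initial weights and start with a fresh optimizer.
    model.load_state_dict(init_state)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Placeholder for the real training loop on this fold.
    x = torch.randn(2, 1, 8, 8)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```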
Let me know if one of these approaches works for you.
My model uses the default random initialization, but I think it is best practice to reset the weights of my current model instance instead of creating a new one. I didn’t understand what you mean by recreating the model; could you please give me more details?
Let me give you more details: I am trying to fit a UNet model on a dataset of around 1000 images. I decided to apply k-fold cross-validation because of the size of the dataset, but after testing my code I noticed that I didn’t reset the model’s weights when a new fold starts. Is this the right way to handle the model in this case?
model = MyModel()
# perform training on first fold
# recreate model
model = MyModel()
# perform training on second fold ...
You could of course recreate these instances in a loop, if that’s more convenient.
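For example, the loop version could look roughly like this. MyModel here is just a one-layer placeholder for your U-Net, and num_folds and the Adam settings are assumptions.

```python
import torch
import torch.nn as nn

class MyModel(nn.Module):
    # Placeholder for the real U-Net definition.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

num_folds = 5  # hypothetical
initial_weights = []

for fold in range(num_folds):
    # New instances per fold: random weights and empty optimizer state.
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Kept only to show that each fold starts from different random weights.
    initial_weights.append(model.conv.weight.detach().clone())

    # ... train and validate on this fold ...
```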
Yes, you should reinitialize the model randomly after it was trained on a specific fold.
If you don’t want to recreate the model instance, you could call the reset_parameters() method on each submodule:
for name, module in model.named_modules():
    if hasattr(module, 'reset_parameters'):
        print('Resetting ', name)
        module.reset_parameters()
Note that I would still recommend recreating the optimizer, as it might store running estimates (e.g. if you are using Adam).
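To illustrate why the optimizer matters, here is a small sketch with a stand-in model and a dummy training step (both assumptions). After one step Adam holds per-parameter state, which is only dropped when a new optimizer is created.

```python
import torch
import torch.nn as nn

# Tiny stand-in for the real U-Net.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.Conv2d(4, 1, kernel_size=3, padding=1),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One dummy training step so that Adam builds up its running estimates.
loss = model(torch.randn(2, 1, 8, 8)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

stale_state_entries = len(optimizer.state)  # per-parameter Adam state
trained_weight = model[0].weight.detach().clone()

# Reset every submodule that defines reset_parameters().
for name, module in model.named_modules():
    if hasattr(module, 'reset_parameters'):
        module.reset_parameters()

# Recreate the optimizer so no state from the previous fold is carried over.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
fresh_state_entries = len(optimizer.state)
```

Modules without a reset_parameters() method (e.g. custom layers that create raw nn.Parameter objects) are skipped by this loop, so they would need to be handled separately.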
Personally, I would do this via a single script file or function that takes the fold index as input. In this case, there will be no problems with resetting the model, optimizer, lr_scheduler, and other stateful objects (e.g. amp if nvidia/apex is used).
So, basic code would look something like this: