How to initialize weights when using a manual K-fold and a custom nn.Module (init_weights how to?)

From the URLs below, I was able to piece together that if I'm using a manual K-fold (not skorch), I need to reinitialize the weights before each new fold. This is the suggested way to do that.

I must be missing something, though:

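import torch.nn as nn
import torch.nn.functional as F

hl = 10  # hidden layer size (not given in the post; placeholder value)

class Net(nn.Module):

    def init_weights(self):
        if (type(self) == nn.Linear):
            pass  # ... weight initialization code (body not preserved in this post) ...

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(4, hl)
        self.fc2 = nn.Linear(hl, 3)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = Net()
net.apply(init_weights) # reset network weights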

NameError: name 'init_weights' is not defined

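Define the init function outside of your model, take the module m as its argument instead of self, and pass it to net.apply(weights_init), which calls it recursively on every submodule (and on the model itself):

def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):  # nn.Linear added so it matches the Net above
        nn.init.xavier_uniform_(m.weight)  # example init; pick your own scheme
        nn.init.zeros_(m.bias)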


Suppose my model has a more complex architecture (for example, a U-Net). What is the best way to reset its weights? I am trying to use K-fold cross-validation and my model is a U-Net, but I don't know how to reset its weights.

I really appreciate your help @ptrblck :slight_smile:


It depends a bit on your use case.

If you haven't written a custom weights_init method and just rely on the default random initialization, I would recommend simply recreating the model.
Note that you should also recreate the optimizer in this case.
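To illustrate why the optimizer has to be recreated together with the model, here is a minimal sketch (the tiny nn.Linear model and SGD are assumptions for illustration only). An optimizer keeps references to the parameter tensors it was constructed with, so after model = MyModel() an old optimizer would keep pointing at the previous model's parameters and would not update the new ones at all.

```python
import torch
import torch.nn as nn

def make_model():
    # Hypothetical stand-in for MyModel(); any nn.Module behaves the same way
    return nn.Linear(4, 3)

model = make_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Mistake to avoid: recreate only the model and keep the old optimizer
model = make_model()
old_params = optimizer.param_groups[0]["params"]
# True if none of the optimizer's tensors belong to the new model
stale = all(p is not q for p in old_params for q in model.parameters())
print("old optimizer tracks the new model:", not stale)  # False

# Recreate the optimizer as well, so it tracks the new parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
new_params = optimizer.param_groups[0]["params"]
print("new optimizer tracks the new model:",
      all(p is q for p, q in zip(new_params, model.parameters())))  # True
```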

On the other hand, if you already defined a custom weights_init method, just reset the model via model.apply(weights_init).
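A minimal sketch of resetting via model.apply(weights_init) between folds, assuming a Xavier-uniform/zero-bias scheme (that choice, and the small nn.Sequential model, are illustrative assumptions). The in-place fill_ only simulates weights that changed during training on a fold.

```python
import torch
import torch.nn as nn

def weights_init(m):
    # Example scheme (an assumption): Xavier-uniform weights, zero biases
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # layers created with bias=False have no bias
            nn.init.zeros_(m.bias)

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
model.apply(weights_init)

# ... train on fold k; simulated here by overwriting every parameter ...
with torch.no_grad():
    for p in model.parameters():
        p.fill_(5.0)

# Before the next fold: re-run the same init on every submodule
model.apply(weights_init)
print(model[0].bias.abs().sum().item())         # 0.0
print(model[0].weight.abs().max().item() < 5.0)  # True
```

Since apply() visits every submodule, the same call works for a deeply nested model such as a U-Net, as long as weights_init handles every layer type you care about.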

Also, not sure if this fits your use case, but you could initialize the model once, create a copy.deepcopy of its state_dict, and reload this state_dict for each fold via model.load_state_dict(state_dict).
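The state_dict approach could be sketched like this (the small model and the add_(1.0) "training" step are placeholders). The deepcopy matters: state_dict() returns tensors that share storage with the live parameters, so without the copy the snapshot would change as soon as training updates the model in place.

```python
import copy
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))  # stand-in for the real model

# Snapshot the initial weights once; deepcopy decouples them from the live parameters
init_state = copy.deepcopy(model.state_dict())

for fold in range(3):
    model.load_state_dict(init_state)  # same starting point for every fold
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # fresh optimizer for your training loop
    # ... train on this fold; simulated here by an in-place update ...
    with torch.no_grad():
        for p in model.parameters():
            p.add_(1.0)

model.load_state_dict(init_state)
print(all(torch.equal(model.state_dict()[k], v) for k, v in init_state.items()))  # True
```

Note the design difference: with this approach every fold starts from identical weights, whereas recreating the model draws a new random initialization for each fold.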

Let me know, if one of these approaches would work. :slight_smile:

My model uses the default random initialization, but I think it is best practice to reset the weights of my current model instance instead of creating a new one. I didn't understand what you mean by recreating the model; could you please give me more details? :slight_smile:

Let me give you more details: I am trying to fit a UNet model on a dataset of around 1000 images. I decided to apply K-fold cross-validation because of the size of the dataset, but after testing my code I noticed that I didn't reset the model's weights when a new fold starts. Is resetting them the right way to handle the model in this case?

Thanks for your reply :slight_smile:

I just meant to use a new instance as shown here:

model = MyModel()
# perform training on first fold

# recreate model
model = MyModel()
# perform training on second fold ...

You could of course recreate these instances in a loop, if that’s more convenient.
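A loop version might look like the sketch below. MyModel is a hypothetical placeholder for your U-Net, the data is synthetic, and the folds are split by hand with Subset (sklearn.model_selection.KFold would be an alternative for generating the indices). A new model, optimizer, and criterion are created at the top of each fold, so nothing carries over between folds.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, TensorDataset

class MyModel(nn.Module):
    # Hypothetical placeholder; substitute your U-Net here
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))

    def forward(self, x):
        return self.net(x)

# Synthetic data (an assumption, just so the sketch runs)
torch.manual_seed(0)
dataset = TensorDataset(torch.randn(100, 4), torch.randint(0, 3, (100,)))

k = 5
indices = torch.randperm(len(dataset)).tolist()
folds = [indices[i::k] for i in range(k)]  # k disjoint lists of sample indices

val_losses = []
for fold_idx in range(k):
    val_idx = folds[fold_idx]
    train_idx = [i for j, f in enumerate(folds) if j != fold_idx for i in f]
    train_loader = DataLoader(Subset(dataset, train_idx), batch_size=16, shuffle=True)
    val_loader = DataLoader(Subset(dataset, val_idx), batch_size=16)

    # Fresh model, optimizer, and loss for every fold
    model = MyModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(2):
        model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

    # Sample-weighted mean validation loss for this fold
    model.eval()
    with torch.no_grad():
        total = sum(criterion(model(x), y).item() * len(y) for x, y in val_loader)
    val_losses.append(total / len(val_idx))
    print(f"fold {fold_idx}: val loss {val_losses[-1]:.4f}")
```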

Yes, you should reinitialize the model randomly after it was trained on a specific fold.
If you don’t want to recreate the model instance, you could call the reset_parameters() method on each submodule:

for name, module in model.named_modules():
    if hasattr(module, 'reset_parameters'):
        print('Resetting ', name)
        module.reset_parameters()

Note that I would still recommend recreating the optimizer, as it might store running estimates (e.g. if you are using Adam).
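A minimal sketch combining the reset_parameters() loop with a fresh optimizer, using a small Conv2d + BatchNorm2d stack as an assumed stand-in for a U-Net block. Many built-in layers (e.g. nn.Linear, nn.Conv2d, nn.BatchNorm2d) define reset_parameters(); for BatchNorm it also clears the running statistics. Parameters registered directly on a custom module that has no reset_parameters() of its own are not touched by this loop.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.BatchNorm2d(4), nn.ReLU())  # stand-in block

# Simulate a trained model: a train-mode forward pass updates BatchNorm running stats,
# and the fill_ overwrites every learnable parameter
model(torch.randn(2, 1, 8, 8))
with torch.no_grad():
    for p in model.parameters():
        p.fill_(5.0)

def reset_all(module):
    # Re-run each submodule's own default initialization where one exists
    for m in module.modules():
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()

reset_all(model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # new optimizer: no Adam state yet

print(model[0].weight.abs().max().item() < 5.0)     # True
print(torch.all(model[1].weight == 1).item())        # True (BatchNorm weight back to ones)
print(torch.all(model[1].running_mean == 0).item())  # True (running stats cleared)
print(len(optimizer.state))                          # 0
```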


Thank you :blush: What if I handle my training and evaluation process through Ignite? Should I define the model, optimizer, trainer, and evaluator per fold?

I'm not sure how this workflow would look in Ignite, but @vfdev-5 certainly knows! :wink:


Personally, I would do this via a single script or function that takes the fold index as input. That way, there are no problems with restarting the model, optimizer, lr_scheduler, and other stateful objects (e.g. amp if nvidia/apex is used).
So, the basic code would look something like this:

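from ignite.engine import create_supervised_trainer, create_supervised_evaluator, Events
from ignite.metrics import Accuracy

def train_on_single_fold(*args, fold_index=0, **kwargs):
    # get_dataflow and initialize_model are your own helpers; everything
    # stateful is created fresh inside this function, once per fold
    train_loader, val_loader = get_dataflow(fold_index=fold_index, **kwargs)
    model, optimizer, criterion, lr_scheduler = initialize_model(**kwargs)
    metrics = {
         "accuracy": Accuracy(),
    }

    trainer = create_supervised_trainer(model, optimizer, criterion)

    @trainer.on(Events.EPOCH_COMPLETED)  # or ITERATION_COMPLETED, depending on the scheduler
    def update_lr_scheduler(_):
        lr_scheduler.step()

    evaluator = create_supervised_evaluator(model, metrics=metrics)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(trainer):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        print("After {} iterations, binary accuracy = {:.2f}"
              .format(trainer.state.iteration, metrics['accuracy']))

    trainer.run(train_loader, max_epochs=kwargs.get("max_epochs", 100))

for fold_index in range(num_folds):  # num_folds: hypothetical, your number of folds
    train_on_single_fold(fold_index=fold_index)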

Hope this helps. Please feel free to ask more about Ignite usage if needed (and don't forget to set the category "ignite", otherwise our team won't get the notification).

PS: @ptrblck, thanks a lot for the mention!
