Reset model weights

Sure! You just have to define your init function:

import torch.nn as nn

def weights_init(m):
    # Re-draw the weights of every Conv2d layer with Xavier uniform init
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)

And call it on the model with:

model.apply(weights_init)
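Here is a minimal end-to-end sketch of the above. It assumes a small, hypothetical nn.Sequential model with one conv layer and one linear layer. It snapshots both weight tensors, re-runs the init through model.apply (which calls the function on every submodule recursively), and checks that only the Conv2d weights were re-drawn. The Linear layer is skipped by the isinstance check.

```python
import torch
import torch.nn as nn


def weights_init(m):
    # Re-draw Xavier-uniform weights for Conv2d layers; other layers are untouched.
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)


# Hypothetical toy model for illustration only (Linear sized for 8x8 inputs).
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 6 * 6, 10),
)

# Snapshot the parameters before re-initializing.
conv_before = model[0].weight.detach().clone()
linear_before = model[3].weight.detach().clone()

# apply() visits every submodule, then the model itself.
model.apply(weights_init)

conv_changed = not torch.equal(conv_before, model[0].weight)
linear_changed = not torch.equal(linear_before, model[3].weight)
print(conv_changed, linear_changed)
```

This should print True False. If you also want to reset other layer types, add more isinstance branches to the same function.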

If you want the same random weights every time you re-initialize, set the seed right before calling model.apply:

torch.manual_seed(your_seed)
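To illustrate the seeding behaviour, here is a sketch. It assumes a CPU model and that no other random operations run between torch.manual_seed and model.apply. The helper reset_with_seed is hypothetical. Re-seeding with the same value reproduces identical weights, and a different seed gives different ones.

```python
import torch
import torch.nn as nn


def weights_init(m):
    # Xavier-uniform init for Conv2d weights only.
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)


def reset_with_seed(model, seed):
    # Hypothetical helper: seeding right before apply() makes the random
    # draws, and therefore the new weights, repeatable.
    torch.manual_seed(seed)
    model.apply(weights_init)
    return model[0].weight.detach().clone()


model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3))

first = reset_with_seed(model, 0)
second = reset_with_seed(model, 0)
other = reset_with_seed(model, 1)

print(torch.equal(first, second))  # same seed -> identical weights
print(torch.equal(first, other))   # different seed -> different weights
```

This prints True, then False. Seeding inside the helper keeps each reset independent of whatever random calls happened earlier in the program.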