Reinitializing the weights after each cross validation fold

This is not a robust solution and won't work for anything except core torch.nn layers, but it works:

# Re-run the default initializer of every direct child layer that defines one.
for layer in model.children():
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()
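One limitation of the loop above is that `model.children()` only yields the immediate children, so layers nested inside a container such as `nn.Sequential` are skipped. A minimal sketch of a recursive variant using `nn.Module.apply`, which visits every submodule, might look like this. The helper name `reset_weights`, the toy model, and `n_folds` are assumptions for illustration, not part of the original post:

```python
import torch
import torch.nn as nn


def reset_weights(module: nn.Module) -> None:
    # Hypothetical helper: re-run the module's own default initializer, if it has one.
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()


# Toy model (assumed) with a nested container, which model.children() alone would not descend into.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
    nn.Linear(8, 1),
)

n_folds = 3  # assumed number of cross-validation folds
for fold in range(n_folds):
    # apply() calls reset_weights on every submodule recursively, then on the model itself.
    model.apply(reset_weights)
    # Create a fresh optimizer for each fold so state such as momentum does not carry over.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # ... train and evaluate on this fold's split here ...
```

The same caveat still applies: custom modules that create parameters directly and do not define `reset_parameters` will keep their old values, so rebuilding the model from scratch for each fold is likely the safer option when that is cheap.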