Reinitializing the weights after each cross-validation fold

Based on this, we probably need to use .modules() so that the reset recursively reaches every submodule in the model, not just the top-level ones:
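A minimal sketch of that idea, assuming PyTorch and assuming that the layers you care about expose the public `reset_parameters()` method (as `nn.Linear`, `nn.Conv2d` and the `nn.BatchNorm*` layers do). The helper name `reset_all_weights` and the toy model are hypothetical, chosen only to show that the traversal reaches a layer nested inside an inner `nn.Sequential`.

```python
import torch
import torch.nn as nn


def reset_all_weights(model: nn.Module) -> None:
    """Reinitialize every submodule that defines reset_parameters().

    model.modules() yields the model itself plus all of its descendants,
    recursively, so layers inside nested containers are covered too.
    """
    for module in model.modules():
        # Layers such as nn.Linear, nn.Conv2d and nn.BatchNorm* define this;
        # containers and activations (nn.Sequential, nn.ReLU) do not.
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()


# Hypothetical model with a nested block, to show the recursion reaches it.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
    nn.Linear(8, 1),
)

for fold in range(3):
    reset_all_weights(model)  # fresh weights at the start of each fold
    # Recreate the optimizer too, so no optimizer state carries over between folds.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # ... train and evaluate on this fold's train/validation split ...
```

An equivalent design is `model.apply(fn)`, which also visits every submodule recursively and calls `fn` on each; iterating over `.modules()` directly just makes the traversal explicit.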