Does this work for you?
This probably works too:
This is not a robust solution and won't work for anything except core torch.nn layers (it relies on each layer defining reset_parameters), but it works:
for layer in model.children():
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()
Credit: Reinitializing the weights after each cross validation fold
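One caveat with the loop above: model.children() only yields the direct children, so layers nested inside a container such as nn.Sequential are never reached. A minimal sketch of a recursive variant, assuming nn.Module.apply is acceptable for your use case; reset_all_parameters and the toy model are hypothetical names for illustration, not something from the original posts:

```python
import torch
from torch import nn


def reset_all_parameters(model: nn.Module) -> None:
    """Re-initialize every submodule that defines reset_parameters()."""

    def _reset(module: nn.Module) -> None:
        # Modules without reset_parameters (e.g. nn.ReLU, nn.Sequential)
        # are skipped; this still only covers layers that define the method.
        reset = getattr(module, "reset_parameters", None)
        if callable(reset):
            reset()

    # nn.Module.apply visits the module itself and all nested submodules.
    model.apply(_reset)


# Hypothetical model with a nested block that a plain children() loop
# would not descend into.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.Sequential(nn.ReLU(), nn.Linear(8, 2)),
)

before = [p.detach().clone() for p in model.parameters()]
reset_all_parameters(model)
after = [p.detach().clone() for p in model.parameters()]
```

Calling reset_all_parameters(model) at the start of each fold would give a fresh initialization, but custom layers that do not define reset_parameters are still left untouched, so the same robustness warning applies.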