Pick the n best weights

You could save several state_dicts (for example, one per checkpoint you want to keep), load them afterwards, and average them key by key into a single state_dict.
A small example for two different state_dicts is given here.
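
For reference, a minimal sketch of the averaging step could look like the following. The helper name `average_state_dicts` and the two `nn.Linear` toy models are hypothetical and only stand in for your own checkpoints; the sketch also assumes all state_dicts come from the same architecture, so that keys and shapes match.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def average_state_dicts(state_dicts):
    """Return a new state_dict holding the element-wise mean of the inputs.

    Hypothetical helper: assumes every state_dict has the same keys and shapes.
    """
    if not state_dicts:
        raise ValueError("need at least one state_dict")
    averaged = OrderedDict()
    for key, reference in state_dicts[0].items():
        # Average in float so that integer buffers can be stacked and averaged too,
        # then cast back to the original dtype (integer buffers are truncated).
        stacked = torch.stack([sd[key].float() for sd in state_dicts])
        averaged[key] = stacked.mean(dim=0).to(reference.dtype)
    return averaged


torch.manual_seed(0)

# Two toy models standing in for two stored checkpoints.
model_a = nn.Linear(4, 2)
model_b = nn.Linear(4, 2)
sd_a = model_a.state_dict()
sd_b = model_b.state_dict()

sd_avg = average_state_dicts([sd_a, sd_b])

# Load the averaged parameters into a fresh model of the same architecture.
model_avg = nn.Linear(4, 2)
model_avg.load_state_dict(sd_avg)
```

The same function works for n state_dicts if you pass a longer list, e.g. ones loaded via `torch.load`. Keep in mind that averaging parameters is not the same as averaging predictions, so it is worth validating the averaged model before relying on it.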