How do I load a new copy of my parameters to be trained in the optimizer without breaking it in a subtle way?

I am using higher, and I have hit an unusual bug where my optimizer no longer holds a reference to the real parameters I am trying to update (details here:). To work around it, I am trying to give the optimizer a reference to the new (unfortunately deep-copied) versions of my parameters so that it can update the right objects in place.

So I made this function:

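    def load_new_params(self, params):
        # Drop the old groups so the optimizer stops pointing at the stale parameters.
        self.param_groups = []

        # Same normalization that torch.optim.Optimizer.__init__ applies to its params argument.
        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            # Loop body completed by analogy with Optimizer.__init__; self.state is left untouched.
            self.add_param_group(param_group)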

I believe this should work, but I noticed that optimizers also have a self.state field, and I worry that it is not being tracked correctly after the swap (or that something else subtle is going wrong).
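To make the worry about self.state concrete, here is a toy sketch in plain Python. It is not the real torch.optim code: ToyParam and ToyMomentumSGD are hypothetical stand-ins I made up, and the sketch assumes (as I understand torch.optim.Optimizer to do) that per-parameter state lives in a dict keyed by the parameter object itself. Under that assumption, swapping only param_groups leaves the old entries behind and the new copies start with empty state:

```python
import copy
from collections import defaultdict


class ToyParam:
    """Hypothetical stand-in for a tensor parameter (hashed by identity)."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


class ToyMomentumSGD:
    """Hypothetical stand-in for an optimizer that keeps per-parameter state."""

    def __init__(self, params, lr=0.1, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.param_groups = [{'params': list(params)}]
        self.state = defaultdict(dict)  # keyed by the parameter object

    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                st = self.state[p]  # creates an empty entry for unseen params
                buf = self.momentum * st.get('momentum_buffer', 0.0) + p.grad
                st['momentum_buffer'] = buf
                p.value -= self.lr * buf

    def load_new_params(self, params):
        # Same idea as my function: only param_groups is replaced.
        self.param_groups = [{'params': list(params)}]


old = [ToyParam(1.0)]
opt = ToyMomentumSGD(old)
old[0].grad = 1.0
opt.step()  # builds a momentum buffer for old[0]

new = copy.deepcopy(old)  # new objects, so new dictionary keys
opt.load_new_params(new)

# The deep copy has no entry in opt.state yet; only the stale key exists.
new_has_state_before_step = new[0] in opt.state
entries_before_step = len(opt.state)

opt.step()  # the new parameter starts from an empty state dict

entries_after_step = len(opt.state)
new_buffer = opt.state[new[0]]['momentum_buffer']
print(new_has_state_before_step, entries_before_step, entries_after_step)
```

In this toy version the update does reach the new copy, but the momentum buffer restarts from scratch and the entry for the old object is never removed, which is exactly the kind of subtle problem I am afraid of in the real optimizer.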

Is doing:


enough to guarantee that the right version of the parameters is updated correctly, and that no other unexpected bug is introduced through self.state or otherwise?