Delete parameter group from optimizer

Hello all,

I need to delete a parameter group from my optimizer.
Here is some sample code showing how I am tackling the problem:

import torch

lstm = torch.nn.LSTM(3,10)
optim = torch.optim.Adam(lstm.parameters())

# train a bit and then delete the parameters from the optimizer
# in order not to train them anymore
del optim.param_groups[0]  # leaves optim.param_groups == []

# create a new LSTM and train it instead of the old one
new_lstm = torch.nn.LSTM(3,10)
optim.add_param_group({'params': list(new_lstm.parameters())})
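
As a quick sanity check (not part of the original snippet, but a follow-up assuming the code above just ran), the optimizer should now hold a single group containing the new parameters:

assert len(optim.param_groups) == 1
assert optim.param_groups[0]['params'][0] is next(new_lstm.parameters())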

Is there an official PyTorch way to delete a parameter group, or do I have to use del?


Ah, I need to do the same while training a model on a very large tree where each node has a coefficient vector, and I only need to backprop over the nodes traversed in a given run.

Did you see any problems with your approach?

I used my approach without experiencing any particular problem. Until now, at least :grinning:

Any new experiences with this way of removing parameters from an existing optimizer?

I used this approach without any problem so far. It always gave me consistent results.


Just curious, so there is no officially documented way to delete params from an optimizer?

Hey,

Any update on how to do this?


For me, I encountered an error. Apparently, the reason is that optimizers also store a state attribute, as shown in this line. I think this state stores things like momentum, which change per iteration.
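
To see what that state holds, here is a minimal sketch (for Adam; other optimizers store different keys):

import torch

lstm = torch.nn.LSTM(3,10)
optim = torch.optim.Adam(lstm.parameters())

# one forward/backward/step so Adam populates its per-parameter state
out, _ = lstm(torch.randn(5, 1, 3))
out.sum().backward()
optim.step()

# optim.state maps each parameter to its running statistics
for p, s in optim.state.items():
    print(tuple(p.shape), list(s.keys()))  # e.g. ['step', 'exp_avg', 'exp_avg_sq']

Deleting a param group alone leaves these entries behind, keyed by the now-removed parameters.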

Having said this, this is how I reset my parameters:

import torch

lstm = torch.nn.LSTM(3,10)
optim = torch.optim.Adam(lstm.parameters())

# train a bit and then delete the parameters from the optimizer
# in order not to train them anymore
optim.param_groups.clear()  # optim.param_groups = []
optim.state.clear()         # optim.state = defaultdict(dict)

# create a new LSTM and train it instead of the old one
new_lstm = torch.nn.LSTM(3,10)
optim.add_param_group({'params': list(new_lstm.parameters())})
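
After the reset, a quick check (again, not part of the original snippet) confirms both containers look as expected; the state repopulates on the next step:

assert len(optim.param_groups) == 1  # only the new LSTM's group
assert len(optim.state) == 0         # cleared; refilled on the next optim.step()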

This issue needs to be tackled by PyTorch. Currently, I am resorting to a non-trivial solution like this:

def get_param_group_index(optim, param):
    # Find the index of the (single-parameter) group holding `param`
    for i, optim_param_group in enumerate(optim.param_groups):
        optim_param_group_list = optim_param_group["params"]
        assert len(optim_param_group_list) == 1
        optim_param = optim_param_group_list[0]
        if param.shape == optim_param.shape and (param == optim_param).all():
            return i
    return None  # alternatively: raise Exception("Could not find param in optim.param_groups")

def remove_param_from_optimizer(optim, pg_index):
    # Remove the corresponding per-parameter state (e.g. Adam's running averages)
    for param in optim.param_groups[pg_index]['params']:
        if param in optim.state:
            del optim.state[param]
    del optim.param_groups[pg_index]
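
Assuming each parameter was registered in its own group (which the assert above requires), a hypothetical usage would look like this:

import torch

lstm = torch.nn.LSTM(3,10)
# one group per parameter, so the helper can match a group to a param
optim = torch.optim.Adam([{'params': [p]} for p in lstm.parameters()])

# drop one parameter (and its optimizer state) from future updates
target = next(lstm.parameters())
pg_index = get_param_group_index(optim, target)
if pg_index is not None:
    remove_param_from_optimizer(optim, pg_index)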