I need to delete a parameter group from my optimizer.
Here is some sample code showing how I am currently tackling the problem:
lstm = torch.nn.LSTM(3,10)
optim = torch.optim.Adam(lstm.parameters())
# train a bit and then delete the parameters from the optimizer
# in order not to train them anymore
del optim.param_groups[0]  # optim.param_groups == []
# create a new LSTM and train it instead of the old one
new_lstm = torch.nn.LSTM(3,10)
optim.add_param_group(
    {'params': [p for p in new_lstm.parameters()]}
)
Is there an official PyTorch way to delete a parameter group, or do I have to use del?
Ah, I need to do the same while training a model on a very large tree where each node has a coefficient vector and I only need to backprop over the ones traversed in a given run.
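For the tree case, one different approach that avoids deleting groups altogether (it is not what was proposed above, just a possible alternative) relies on the fact that the built-in optimizers skip any parameter whose .grad is None. A minimal sketch, assuming each node's coefficient vector is its own torch.nn.Parameter; node_coeffs and traversed are made-up names for illustration:

```python
import torch

# Hypothetical stand-in for a tree: one coefficient vector per node.
node_coeffs = [torch.nn.Parameter(torch.ones(4)) for _ in range(3)]
optim = torch.optim.Adam(node_coeffs, lr=0.1)

traversed = [0, 2]  # nodes visited in this run (made-up example)

# Reset grads to None so that unvisited nodes carry no stale gradient.
optim.zero_grad(set_to_none=True)

# Only the traversed nodes take part in the loss, so only they get a grad.
loss = sum(node_coeffs[i].sum() for i in traversed)
loss.backward()
optim.step()

# Node 1 was not traversed: its grad is still None and Adam left it alone.
print(node_coeffs[1].grad)
```

Whether this fits depends on the tree size, since every node still has to be registered with the optimizer up front.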
When I tried this, I ran into an error. Apparently the reason is that optimizers also keep a state attribute, as shown in this line. I think this state stores per-parameter values such as momentum buffers, which change on every iteration.
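To see what that state looks like, here is a small sketch. It does not reproduce the error itself; it only shows that Adam keeps per-parameter entries in optim.state after a step, and that deleting the parameter group alone leaves them behind. The dummy input and the single training step are assumptions for illustration:

```python
import torch

lstm = torch.nn.LSTM(3, 10)
optim = torch.optim.Adam(lstm.parameters())

# One dummy training step so that Adam populates its per-parameter state.
out, _ = lstm(torch.randn(5, 1, 3))  # input shape: (seq_len, batch, input_size)
out.sum().backward()
optim.step()

# Inspect the state kept for the first parameter.
first_param = optim.param_groups[0]["params"][0]
state_keys = sorted(optim.state[first_param].keys())
print(state_keys)

# Deleting the group alone does not touch optim.state.
del optim.param_groups[0]
stale_entries = len(optim.state)
print(stale_entries)  # still one entry per old LSTM parameter
```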
Having said this, this is how I reset my parameters:
lstm = torch.nn.LSTM(3,10)
optim = torch.optim.Adam(lstm.parameters())
# train a bit and then delete the parameters from the optimizer
# in order not to train them anymore
optim.param_groups.clear()  # optim.param_groups == []
optim.state.clear()  # optim.state == defaultdict(dict)
# create a new LSTM and train it instead of the old one
new_lstm = torch.nn.LSTM(3,10)
optim.add_param_group(
    {'params': [p for p in new_lstm.parameters()]}
)
This issue really ought to be handled by PyTorch itself. For now I am resorting to a non-trivial workaround like this:
def get_param_group_index(optim, param):
    for i, optim_param_group in enumerate(optim.param_groups):
        optim_param_group_list = optim_param_group["params"]
        assert len(optim_param_group_list) == 1
        optim_param = optim_param_group_list[0]
        if param.shape == optim_param.shape and (param == optim_param).all():
            return i
    # raise Exception("Could not find param in optim.param_groups")

def remove_param_from_optimizer(optim, pg_index):
    # Remove corresponding state
    for param in optim.param_groups[pg_index]['params']:
        if param in optim.state:
            del optim.state[param]
    del optim.param_groups[pg_index]
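One thing to watch with the lookup above: it compares by shape and value, so two parameters that happen to hold identical values can be confused with each other. If you still have the parameter object itself, matching by identity avoids that. A minimal sketch of that variant, where find_group_index and remove_group are hypothetical names (not PyTorch API) and SGD with momentum is used only so that there is some state to clean up:

```python
import torch

def find_group_index(optim, param):
    """Return the index of the group holding `param` (matched by identity), or None."""
    for i, group in enumerate(optim.param_groups):
        if any(p is param for p in group["params"]):
            return i
    return None

def remove_group(optim, pg_index):
    """Drop the group at `pg_index` together with its per-parameter state."""
    for p in optim.param_groups[pg_index]["params"]:
        optim.state.pop(p, None)
    del optim.param_groups[pg_index]

a = torch.nn.Parameter(torch.zeros(2))
b = torch.nn.Parameter(torch.zeros(2))  # same shape and values as `a`
optim = torch.optim.SGD([{"params": [a]}, {"params": [b]}], lr=0.1, momentum=0.9)

(a.sum() + b.sum()).backward()
optim.step()  # creates a momentum buffer for each parameter

idx = find_group_index(optim, b)  # a value-based match would have hit `a` first
remove_group(optim, idx)
print(idx, len(optim.param_groups), len(optim.state))
```

Unlike the version above, this also works when a group holds more than one parameter, since it does not assert a group size of 1.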