What about using the load_state_dict on a modified version of it?
def adjust_learning_rate(optimizer, iter, each):
# sets the learning rate to the initial LR decayed by 0.1 every 'each' iterations
lr = args.lr * (0.1 ** (iter // each))
state_dict = optimizer.state_dict()
for param_group in state_dict['param_groups']:
param_group['lr'] = lr
optimizer.load_state_dict(state_dict)
return lr