I am somewhat confused about the behavior of some LR schedulers, e.g. LinearLR and ConstantLR.
The docs [here and here] say that they decay the learning rate. But in the docs example and in my own code, they actually increase it. What am I getting wrong?
import numpy as np
import torch
import matplotlib.pyplot as plt

np.set_printoptions(precision=5, suppress=True)

model = torch.nn.Linear(2, 1)
lr = 0.001
epochs = 100

def run_experiment(optimizer, scheduler, epochs=epochs):
    # Record the learning rate used in each epoch, then advance the scheduler.
    lrs = []
    for i in range(epochs):
        optimizer.step()
        lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()
    lrs = np.array(lrs)
    plt.plot(range(epochs), lrs)
    plt.scatter(range(epochs), lrs)
    print(lrs)

optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.5, total_iters=epochs)
run_experiment(optimizer, scheduler)
# Outputs: [0.0005 0.00051 0.00051 0.00052 ..., 0.001]
Regards
Arash
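You have used the argument start_factor. As per the docs here, start_factor is the number the learning rate is multiplied by in the first epoch; the multiplier then changes linearly until it reaches end_factor (default 1.0) after total_iters steps. With start_factor=0.5 the learning rate therefore starts at half its base value and increases back to the base value. To decay the learning rate continuously, you will have to set start_factor to 1 and use a smaller end_factor, which leads to a smaller result after each multiplication.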
import numpy as np
import torch

np.set_printoptions(precision=5, suppress=True)

model = torch.nn.Linear(2, 1)
lr = 0.001
epochs = 100

def run_experiment(optimizer, scheduler, epochs=epochs):
    # Record the learning rate used in each epoch, then advance the scheduler.
    lrs = []
    for i in range(epochs):
        optimizer.step()
        lrs.append(optimizer.param_groups[0]["lr"])
        scheduler.step()
    lrs = np.array(lrs)
    print(lrs)

optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# The multiplier goes linearly from 1 down to 0.1 over `epochs` steps.
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1, end_factor=0.1, total_iters=epochs)
run_experiment(optimizer, scheduler)
Which gives the result:
[0.001 0.00099 0.00098 0.00097 0.00096 0.00096 0.00095 0.00094 0.00093
0.00092 0.00091 0.0009 0.00089 0.00088 0.00087 0.00086 0.00086 0.00085
0.00084 0.00083 0.00082 0.00081 0.0008 0.00079 0.00078 0.00077 0.00077
0.00076 0.00075 0.00074 0.00073 0.00072 0.00071 0.0007 0.00069 0.00068
0.00068 0.00067 0.00066 0.00065 0.00064 0.00063 0.00062 0.00061 0.0006
0.00059 0.00059 0.00058 0.00057 0.00056 0.00055 0.00054 0.00053 0.00052
0.00051 0.0005 0.0005 0.00049 0.00048 0.00047 0.00046 0.00045 0.00044
0.00043 0.00042 0.00041 0.00041 0.0004 0.00039 0.00038 0.00037 0.00036
0.00035 0.00034 0.00033 0.00032 0.00032 0.00031 0.0003 0.00029 0.00028
0.00027 0.00026 0.00025 0.00024 0.00023 0.00023 0.00022 0.00021 0.0002
0.00019 0.00018 0.00017 0.00016 0.00015 0.00014 0.00014 0.00013 0.00012
0.00011]
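If it helps to see why the direction depends only on the two factors, here is a minimal sketch in plain Python (no torch needed). The helper names linear_lr_factor and constant_lr_factor are made up for illustration and are not part of PyTorch; they are meant to mirror the closed-form multipliers described in the LinearLR and ConstantLR docs, so treat them as an approximation of the documented behavior rather than the library's implementation.

```python
def linear_lr_factor(step, start_factor, end_factor=1.0, total_iters=5):
    """Hypothetical helper: multiplier applied to the base LR at `step`.

    Interpolates linearly from start_factor to end_factor over total_iters
    steps and stays at end_factor afterwards.
    """
    progress = min(step, total_iters) / total_iters
    return start_factor + (end_factor - start_factor) * progress


def constant_lr_factor(step, factor=1.0 / 3, total_iters=5):
    """Hypothetical helper: ConstantLR-style multiplier.

    Uses `factor` until total_iters is reached, then restores the base LR.
    """
    return factor if step < total_iters else 1.0


base_lr = 0.001

# start_factor < end_factor: the LR rises towards the base value (warm-up).
warmup = [base_lr * linear_lr_factor(s, 0.5, 1.0, 100) for s in range(100)]

# start_factor > end_factor: the LR falls (decay), as in the answer above.
decay = [base_lr * linear_lr_factor(s, 1.0, 0.1, 100) for s in range(100)]

print(warmup[0], warmup[-1])
print(decay[0], decay[-1])
```

The same reasoning would explain the ConstantLR observation from the question: the learning rate is held at a reduced value and then jumps back up to the base value once total_iters is reached, so relative to its starting point it increases.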