Hi,
I would like to create a learning-rate warm-up phase using SequentialLR to transition from ExponentialLR to CosineAnnealingLR.
Why? When going from LinearLR to CosineAnnealingLR, the learning rate is already close to the max for most of the warm-up phase. Here is my code for this case:
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR, ExponentialLR
from torch import optim
from torch import nn
import matplotlib.pyplot as plt
import numpy as np

max_lr = 0.01
initial_lr = max_lr / 1e4
model = nn.Linear(10, 10)
opt = optim.SGD(model.parameters(), lr=max_lr, momentum=0.9)
num_epochs = 20

# linear warm-up over the first 8 epochs, from max_lr/1e4 up to max_lr
warmup_scheduler = LinearLR(opt,
                            start_factor=1/1e4,
                            end_factor=1,
                            total_iters=8)
# cosine decay over the remaining epochs
cos_scheduler = CosineAnnealingLR(opt,
                                  T_max=num_epochs - 8,
                                  eta_min=max_lr / 10000.)
# hand off from the warm-up to the cosine schedule at epoch 8
scheduler = SequentialLR(opt,
                         schedulers=[warmup_scheduler, cos_scheduler],
                         milestones=[8])

lrs = []
for epoch in range(num_epochs):
    lr = scheduler.get_last_lr()[0]
    lrs.append(lr)
    print(f"epoch: {epoch}: lr = {lr}")
    scheduler.step()

fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(1, 1, 1)
ax.plot(np.arange(len(lrs)), lrs)
# ax.set_yscale('log')
fig.tight_layout()
plt.savefig('foo.png', bbox_inches='tight')
plt.show()
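To make my complaint about the linear ramp concrete: with start_factor=1/1e4 and total_iters=8, LinearLR's very first step already takes the LR from 1e-6 to max_lr * (1e-4 + (1 - 1e-4) / 8), roughly 1.25e-3, i.e. an eighth of max_lr. On a log scale the warm-up is essentially over after one step, which is why I want a geometric (exponential) ramp instead.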
When replacing LinearLR with ExponentialLR, I have to compute gamma myself.
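Since ExponentialLR multiplies the learning rate by gamma once per step, ramping from initial_lr to max_lr over the 8 warm-up epochs (7 multiplicative steps) means solving initial_lr * gamma**7 = max_lr. A quick sanity check of that arithmetic (plain NumPy, nothing PyTorch-specific):

import numpy as np

max_lr = 0.01
initial_lr = max_lr / 1e4
gamma = np.power(max_lr / initial_lr, 1 / 7)
print(gamma)                    # ~3.7276
print(initial_lr * gamma**7)    # 0.01, i.e. max_lr after 7 steps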
The problem is that after the milestone the second scheduler seems stuck at the final learning rate for the whole cosine phase, instead of following a cosine decay.
Here is my code:
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR, ExponentialLR
from torch import optim
from torch import nn
import matplotlib.pyplot as plt
import numpy as np

max_lr = 0.01
initial_lr = max_lr / 1e4
model = nn.Linear(10, 10)
# note: the optimizer now starts at initial_lr (not max_lr as above),
# since ExponentialLR can only grow the LR from the optimizer's starting value
opt = optim.SGD(model.parameters(), lr=initial_lr, momentum=0.9)
num_epochs = 20

# gamma chosen so that initial_lr * gamma**7 == max_lr
gamma = np.power(max_lr / initial_lr, 1 / 7)
warmup_scheduler = ExponentialLR(opt, gamma=gamma)
cos_scheduler = CosineAnnealingLR(opt,
                                  T_max=num_epochs - 8,
                                  eta_min=max_lr / 10000.)
scheduler = SequentialLR(opt,
                         schedulers=[warmup_scheduler, cos_scheduler],
                         milestones=[8])

lrs = []
for epoch in range(num_epochs):
    lr = scheduler.get_last_lr()[0]
    lrs.append(lr)
    print(f"epoch: {epoch}: lr = {lr}")
    scheduler.step()

fig = plt.figure(figsize=(8, 4))
ax = fig.add_subplot(1, 1, 1)
ax.plot(np.arange(len(lrs)), lrs)
# ax.set_yscale('log')
fig.tight_layout()
plt.savefig('foo.png', bbox_inches='tight')
plt.show()
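My guess (unverified) is that the base learning rate is the culprit: CosineAnnealingLR anneals down from the optimizer's initial LR, which is max_lr in the LinearLR version but only initial_lr here, since ExponentialLR forced me to start the optimizer at initial_lr. If that is right, one workaround I am considering is to keep the optimizer at max_lr and express the exponential ramp as a multiplicative factor via LambdaLR. A sketch of what I mean, reusing model, num_epochs, max_lr and initial_lr from above (warmup_epochs is my own name, and I have not verified this end to end):

from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, SequentialLR

warmup_epochs = 8
warmup_factor = initial_lr / max_lr   # 1e-4

# optimizer starts at max_lr, as in the LinearLR version
opt = optim.SGD(model.parameters(), lr=max_lr, momentum=0.9)

# geometric ramp: factor goes from warmup_factor (epoch 0) to 1 (epoch warmup_epochs - 1)
warmup_scheduler = LambdaLR(
    opt,
    lr_lambda=lambda epoch: warmup_factor ** (1 - epoch / (warmup_epochs - 1)),
)
cos_scheduler = CosineAnnealingLR(opt,
                                  T_max=num_epochs - warmup_epochs,
                                  eta_min=max_lr / 10000.)
scheduler = SequentialLR(opt,
                         schedulers=[warmup_scheduler, cos_scheduler],
                         milestones=[warmup_epochs])

Is that the right way to think about it, or is there a cleaner fix?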
Thanks for any help.