Hello,
I am trying to implement transfer learning in a deep neural network. I train with Adam first and then switch to L-BFGS. I have a few questions about the L-BFGS part of my implementation.
- ‘max_iter’ is the maximum number of iterations per optimization step. So shouldn't the loss change after each step? When I print the loss after calling optimizer2.step(closure), I get the same value over and over again, across many steps.
- How do I access the iteration counter inside the closure function so that I can plot losses and results?
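The first question can be illustrated on a toy problem. This is a minimal sketch; the quadratic loss, `w`, `target` and the counter dict are hypothetical and not from the post. One call to `torch.optim.LBFGS.step(closure)` runs up to `max_iter` inner iterations, may call the closure several times, and returns the loss from the first closure call of that step. A variable outside the closure only changes if you capture that return value (or recompute the loss).

```python
import torch

# Hypothetical toy problem: find w close to target by minimizing a quadratic.
target = torch.tensor([3.0, -2.0])
w = torch.zeros(2, requires_grad=True)

# max_iter bounds the inner L-BFGS iterations performed by ONE call to step().
optimizer = torch.optim.LBFGS([w], lr=1.0, max_iter=20)

counter = {"calls": 0}  # mutable dict so the closure can update it
calls_per_step = []     # closure calls made during each step() call
start_losses = []       # value returned by each step() call

def closure():
    counter["calls"] += 1
    optimizer.zero_grad()
    loss = ((w - target) ** 2).sum()
    loss.backward()
    return loss

for step in range(3):
    counter["calls"] = 0
    # step() returns the loss from the first closure call of this step,
    # i.e. the loss BEFORE this step's parameter updates.
    loss_at_start = optimizer.step(closure)
    start_losses.append(loss_at_start.item())
    calls_per_step.append(counter["calls"])
    print(f"step {step}: loss at start {start_losses[-1]:.3e}, "
          f"closure calls {calls_per_step[-1]}")
```

The closure-call count per step is bounded by `max_iter` here (and by `max_eval`, which defaults to `max_iter * 5 // 4`), so printing only once per `step()` hides most of the evaluations.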
My code looks something like this:
import time
import torch

optimizer1 = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999),
                              eps=1e-08, weight_decay=0, amsgrad=False)
optimizer2 = torch.optim.LBFGS(model.parameters(), lr=0.0005,
                               max_iter=20, max_eval=None,
                               tolerance_grad=1e-07, tolerance_change=1e-09,
                               history_size=100, line_search_fn=None)

def closure():
    optimizer2.zero_grad()
    loss_ = Loss(..some parameters..)
    loss_.backward(retain_graph=True)
    print("----------- Epoch LBFGS: " + str(loss_.item()) + "----------------", flush=True)
    return loss_

start_time = time.time()
loss_plot = []
print_epoch = 0
for t in range(1, no_epochs):
    optimizer1.zero_grad()
    loss = Loss(...some parameters.....)
    loss.backward(retain_graph=True)
    print("\n----------- Epoch " + str(t) + ": " + str(loss.item()) + "----------------\n", flush=True)
    loss_plot.append(loss.item())
    optimizer1.step()
    my_loss_plot(loss_plot, path + '/log_loss_per_epoch.jpg')
    if t % 10 == 0:
        my_value_plot(...some parameters...)
    if t % 20 == 0:
        my_value_plotMACROS31(...some parameters...)
        my_f_plot31(...some parameters...)
    if loss.item() < 1e-4:
        print_epoch = t
        break

print("\n\nSWITCHING OPTIMIZER: Adam -> L-BFGS\n\n", flush=True)
for t in range(print_epoch, 1000):
    optimizer2.step(closure)
    print("\n----------- Epoch " + str(t) + ": " + str(loss.item()) + "----------------\n", flush=True)
    loss_plot.append(loss.item())
    my_loss_plot(loss_plot, path + '/log_loss_per_epoch.jpg')
    if t % 40 == 0:
        my_value_plot(...some parameters...)
    if t % 45 == 0:
        my_value_plotMACROS31(...some parameters...)
        my_f_plot31(...some parameters...)
    if loss.item() < 1e-9:  # I am not getting updated values here!
        break

print("\n--- %s seconds ---" % (time.time() - start_time))
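In the L-BFGS loop above, `loss` is never reassigned after the Adam loop ends, so `loss.item()` keeps printing the last Adam loss. A possible fix is sketched below under stated assumptions: the linear model, data and `compute_loss` are hypothetical stand-ins for the post's elided `model` and `Loss(...)`, and `eval_counter` is an illustrative name. It captures the return value of `optimizer2.step(closure)`, recomputes the post-step loss, and keeps a counter inside the closure for plotting. It also swaps in `lr=1.0` (PyTorch's default) and `line_search_fn="strong_wolfe"`. Without a line search, `lr` is used as a fixed step size after the first inner iteration, so a value as small as 0.0005 may on its own make the loss change very little between steps.

```python
import torch

# Hypothetical stand-ins for the post's model and Loss(...): noise-free linear regression.
torch.manual_seed(0)
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.5], [-2.0], [0.5]]) + 0.1
model = torch.nn.Linear(3, 1)

def compute_loss():  # hypothetical replacement for Loss(..some parameters..)
    return torch.nn.functional.mse_loss(model(x), y)

optimizer2 = torch.optim.LBFGS(model.parameters(), lr=1.0, max_iter=20,
                               line_search_fn="strong_wolfe")

loss_plot = []           # one entry per closure evaluation
step_losses = []         # one entry per optimizer2.step() call, measured after the step
eval_counter = {"n": 0}  # mutable container, so the closure can update it without `global`

def closure():
    optimizer2.zero_grad()
    loss_ = compute_loss()
    loss_.backward()
    eval_counter["n"] += 1
    loss_plot.append(loss_.item())  # log every evaluation, indexed by eval_counter
    return loss_

with torch.no_grad():
    initial_loss = compute_loss().item()

for t in range(10):
    # Capture the return value: the loss from the first closure call of this step.
    loss = optimizer2.step(closure)
    with torch.no_grad():
        current = compute_loss().item()  # loss after this step's updates
    step_losses.append(current)
    print(f"step {t}: start {loss.item():.3e}, end {current:.3e}, "
          f"closure calls so far {eval_counter['n']}")
    if current < 1e-9:
        break
```

Inside the closure, `eval_counter["n"]` plays the role of the iteration variable from the second question. PyTorch's `LBFGS` also keeps running `n_iter` and `func_evals` counts in `optimizer2.state`, but that is an internal detail keyed on the first parameter, so an explicit counter is likely the safer choice. The sketch calls `backward()` without `retain_graph=True`, which is enough for this toy loss; whether the post's `Loss(...)` needs it cannot be told from the elided code.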