so i save my the model as a checkpoint using the following code
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss
}, SAVE_PATH + "Regressor_epoch{}_step{}.pkl".format(epoch, steps))
once i load the optimizer using optimizer.load_state_dict()
and resume training my model starts from the initial LR which is 0.001 and not 1e-9
given below is my training code
writer = SummaryWriter()
model = model.cuda()
criterion = nn.SmoothL1Loss().cuda()
#optimizer = torch.optim.Adam(model.parameters(), lr=inf2['optimizer_state_dict']['param_groups'][0]['lr'])
# optimizer.state_dict()['param_groups']['params'] = inf2['optimizer_state_dict']['param_groups']['params']
optimizer.load_state_dict(inf2['optimizer_state_dict'])
best_val_loss = inference_dict['val_loss']
best_train_loss = inference_dict['train_loss']
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[4, 8, 15, 20, 45], gamma=0.1)
num_epochs = 150
running_loss = 0
steps = 0
print_every = 35
log_every = 10
log_step = 0
for epoch in range(num_epochs):
model.train()
scheduler.step()
for data_ in trainloader:
steps += 1
img, bbox = data_
img = Variable(img.cuda())
target = Variable(bbox.cuda())
output = model(img)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
train_loss = running_loss/steps
writer.add_scalar('Training Loss', train_loss, steps)
writer.add_scalar('Learning rate', optimizer.state_dict()['param_groups'][0]['lr'], epoch)
if train_loss < best_train_loss:
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss
}, SAVE_PATH + "Regressor_epoch{}_step{}.pkl".format(epoch, steps))
best_train_loss = train_loss
if steps % print_every == 0:
print("Epoch: {}/{}.. ".format(epoch+1, num_epochs),
"Training Loss: {:.4f}.. ".format(train_loss),
"Learning Rate: {}".format(optimizer.state_dict()['param_groups'][0]['lr']))
writer.close()