How to save and load the lr_scheduler state in PyTorch?

I'm using an lr_scheduler to decrease the learning rate. In order to resume my training I need to restore the scheduler's state, but I have no idea how to do it. Here is what I have done:

...
 optimizer = torch.optim.Adadelta(net.parameters(), lr=0.1, rho=0.9, eps=1e-3, # momentum=state['momentum'],
                                     weight_decay=0.001)
 milestones = [5, 10, 15, 20, 40, 50]
 scheduler = lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1)

  if use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(epochs)
  # optionally resume from a checkpoint
  if resume:
    if os.path.isfile(resume):
      print_log("=> loading checkpoint '{}'".format(resume), log)
      checkpoint = torch.load(resume)
      recorder = checkpoint['recorder']
      start_epoch = checkpoint['epoch']
      scheduler = checkpoint['scheduler']
      net.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(resume, checkpoint['epoch']), log)
    else:
      print_log("=> no checkpoint found at '{}'".format(resume), log)
  else:
    print_log("=> did not use any checkpoint for {} model".format(arch), log)

which is saved like this:

  save_checkpoint({
      'epoch': epoch + 1,
      'arch': arch,
      'state_dict': net.state_dict(),
      'recorder': recorder,
      'optimizer' : optimizer.state_dict(),
       'scheduler': scheduler,
    }, is_best, save_path, 'checkpoint_{0}.pth.tar'.format(time_stamp), time_stamp)

and save_checkpoint itself is defined as:

def save_checkpoint(state, is_best, save_path, filename, timestamp=''):
  filename = os.path.join(save_path, filename)
  torch.save(state, filename)
  if is_best:
    bestname = os.path.join(save_path, 'model_best_{0}.pth.tar'.format(timestamp))
    shutil.copyfile(filename, bestname)

Yet this does not work. Can anyone help me in this regard, please?
Thanks a lot in advance.


Try saving scheduler.state_dict() instead of the scheduler object itself; when you call load_state_dict on a freshly created scheduler, its last_epoch (and the rest of its state) will be restored.
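
For example, a minimal sketch (the checkpoint file name is just a placeholder, and net, optimizer, and scheduler are the objects from your snippet) could look like this:

 # save: store the state_dicts, not the live objects
 torch.save({
     'epoch': epoch + 1,
     'state_dict': net.state_dict(),
     'optimizer': optimizer.state_dict(),
     'scheduler': scheduler.state_dict(),
 }, 'checkpoint.pth.tar')

 # resume: recreate net, optimizer, and scheduler as usual, then load their states
 checkpoint = torch.load('checkpoint.pth.tar')
 net.load_state_dict(checkpoint['state_dict'])
 optimizer.load_state_dict(checkpoint['optimizer'])
 scheduler.load_state_dict(checkpoint['scheduler'])
 start_epoch = checkpoint['epoch']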


Thank you very much. I saved it the way you mentioned, and for resuming I used:
scheduler.load_state_dict(checkpoint['scheduler'])
and that's all; no need to change anything else.
So basically this is how the resume section is defined now:

 # optionally resume from a checkpoint
  if resume:
    if os.path.isfile(resume):
      print_log("=> loading checkpoint '{}'".format(resume), log)
      checkpoint = torch.load(resume)
      recorder = checkpoint['recorder']
      start_epoch = checkpoint['epoch']
      scheduler.load_state_dict(checkpoint['scheduler'])
      net.load_state_dict(checkpoint['state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer'])
      print_log("=> loaded checkpoint '{}' (epoch {})" .format(resume, checkpoint['epoch']), log)
    else:
      print_log("=> no checkpoint found at '{}'".format(resume), log)
  else:
    print_log("=> did not use any checkpoint for {} model".format(arch), log)

Hey, I have a similar problem with torch.save when using this scheduler. The code is from an open-source project, so it should work, but after the checkpoint directory is created, nothing is saved by torch.save; the folder stays empty. I have checked the code many times and everything is the same as in the original. Could you please help with that? Thank you.

dir_img = './imgs/'
dir_mask = './masks/'
dir_checkpoint = 'checkpoints/'

def train_net(net,
              device,
              epochs=2,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        #os.mkdir(dir_checkpoint)
        #logging.info('Created checkpoint directory')
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                #masks_pred = net(imgs)
                #loss = criterion(masks_pred, true_masks)
                logits, probs, masks_pred = net(imgs)  # logits, probas, preds
                loss = criterion(logits, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (len(dataset) // (10 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
                    val_score = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

                    if net.n_classes > 1:
                        logging.info('Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info('Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks, global_step)
                        writer.add_images('masks/pred', torch.sigmoid(logits) > 0.5, global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            print('before torch.save')
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            print('after torch.save')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()

def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=2,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,
                        help='Learning rate', dest='lr')
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
                        help='Downscaling factor of the images')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')

    return parser.parse_args()

if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    #   - For 1 class and background, use n_classes=1
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    net = RNet(n_channels=3, n_classes=1)  # , bilinear=True)
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n')
                 # f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

    # if args.load:
    #     net.load_state_dict(
    #         torch.load(args.load, map_location=device))
    #     logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)

You are most likely missing the / to separate the file name from the folder.
This should work:

torch.save(net.state_dict(), dir_checkpoint + f'/CP_epoch{epoch + 1}.pth')

Right now the checkpoint is most likely being stored in the current working directory, with dir_checkpoint as part of the file name.
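
As a side note, a small sketch using os.path.join (with the same dir_checkpoint variable from your script) avoids having to get the slash right at all:

 import os

 os.makedirs(dir_checkpoint, exist_ok=True)  # create the folder if it doesn't exist yet
 checkpoint_path = os.path.join(dir_checkpoint, f'CP_epoch{epoch + 1}.pth')
 torch.save(net.state_dict(), checkpoint_path)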

PS: You can post code by wrapping it into three backticks ```, which would make debugging easier. :wink:

Thanks for the reply. Haha, yes, I agree, it's a good hint. I still cannot save anything in the created folder, but when I run these two lines of code separately from the train method, torch.save does save something. The '/' wasn't in the original code either. What could possibly be wrong? I am running this code on a server.

Check that you have write permission for the specified folder.
I would assume you should get an error in that case, but I'm unsure what might be going wrong.
I was able to store the model both the wrong way (with the folder name in the file name) and the right way using your script.
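
One rough way to check that (just a sketch, using the dir_checkpoint folder from your script) is to test whether the directory is writable before training starts:

 import os

 if not os.access(dir_checkpoint, os.W_OK):
     raise PermissionError(f'No write permission for {os.path.abspath(dir_checkpoint)}')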

Thanks for your hint, I will keep trying. I never thought about write permissions; maybe that is the problem.