Saving and loading a model in PyTorch?

Hi, I’m trying to implement training with checkpoints using the above ideas, so that I can resume training from, say, Epoch k and continue training the model from Epoch k to N. Suppose I’ve saved the following into the model file and reloaded it when resuming training: the epoch, the model’s state_dict(), and the optimizer. However, I’m not seeing similar training results between these two approaches:

  1. Train the model from Epoch 1 to N.
  2. Train the model from Epoch 1 to k, save the model, and resume training from Epoch k to N.

I checked that the learning rates are consistent between 1) and 2), using SGD with the same momentum and weight decay rates.

Any ideas where I should be looking?
Thanks!


@galactica147
Can you reproduce this in a small dummy/toy script? I’m interested in looking into it if you can repro it.
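A toy script of this kind might simply compare the two runs from the question directly. The following is a minimal sketch, assuming a hypothetical CPU-only linear-regression task as a stand-in for the real detector; it saves `optimizer.state_dict()` (rather than the optimizer object) and loads it into a freshly built optimizer, then checks that the resumed run ends with the same weights as the uninterrupted one.

```python
import io

import torch
from torch import nn
from torch.optim import SGD


def make_data():
    # Hypothetical toy regression data, generated from its own seeded generator.
    g = torch.Generator().manual_seed(0)
    x = torch.randn(64, 5, generator=g)
    y = x @ torch.randn(5, 1, generator=g) + 0.1 * torch.randn(64, 1, generator=g)
    return x, y


def make_model_and_opt():
    torch.manual_seed(0)  # identical initial weights for every run
    model = nn.Linear(5, 1)
    opt = SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
    return model, opt


def train_epochs(model, opt, x, y, start, end):
    for epoch in range(start, end):
        for i in range(0, len(x), 16):  # fixed mini-batch order, no shuffling
            xb, yb = x[i:i + 16], y[i:i + 16]
            opt.zero_grad()
            loss = nn.functional.mse_loss(model(xb), yb)
            loss.backward()
            opt.step()


x, y = make_data()
N, k = 10, 5

# Run 1: epochs 0..N-1 without interruption.
model_a, opt_a = make_model_and_opt()
train_epochs(model_a, opt_a, x, y, 0, N)

# Run 2: epochs 0..k-1, then save a checkpoint.
model_b, opt_b = make_model_and_opt()
train_epochs(model_b, opt_b, x, y, 0, k)
buf = io.BytesIO()
torch.save({'epoch': k,
            'state_dict': model_b.state_dict(),
            'optimizer': opt_b.state_dict()}, buf)
buf.seek(0)

# "Restart": new model and optimizer objects, state restored from the checkpoint.
ckpt = torch.load(buf)
model_c, opt_c = make_model_and_opt()
model_c.load_state_dict(ckpt['state_dict'])
opt_c.load_state_dict(ckpt['optimizer'])  # restores the momentum buffers too
train_epochs(model_c, opt_c, x, y, ckpt['epoch'], N)

print(torch.allclose(model_a.weight, model_c.weight))
```

On CPU with a fixed batch order the two runs should agree; if a variant of this using the real checkpointing code diverges, that points at what is (or isn't) being restored.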

Hi @smth, thanks for your reply! It will be hard to reproduce it fully, but I’ll try my best to provide as many details as possible.

My task is object detection in 3D images. I use a U-shaped ResNet architecture in a Faster R-CNN framework, i.e., the network finds objects (objectness score/probability) and their locations (bounding-box coordinates).

So in training, I have something like this:

for i, (data, target) in enumerate(data_loader):
    # non_blocking=True replaces the old async=True (async is a reserved word since Python 3.7);
    # Variable wrappers are no longer needed, tensors are used directly
    data = data.cuda(non_blocking=True)
    target = target.cuda(non_blocking=True)

    out_dict = net(data)
    loss_output = loss(out_dict['predictions'], target, train=True)
    optimizer.zero_grad()
    loss_output.backward()
    optimizer.step()

if epoch % save_freq == 0:
    save_model(net, optimizer, os.path.join(save_dir, '%03d.ckpt' % epoch))

The following is saved in the model file:

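def save_model(net, optim, ckpt_fname):
    # net is wrapped (e.g. in nn.DataParallel), hence .module
    state_dict = net.module.state_dict()
    for key in state_dict.keys():
        state_dict[key] = state_dict[key].cpu()

    torch.save({
        'epoch': epoch,  # epoch comes from the enclosing scope
        'state_dict': state_dict,
        'optimizer': optim},  # the optimizer object itself, not optim.state_dict()
        ckpt_fname)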
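One thing worth checking with this approach (an inference, not something confirmed in this thread): pickling the optimizer object also pickles the parameter tensors it holds. After loading, those are separate tensors from the parameters of the freshly built `net`, so `optimizer.step()` updates the unpickled copies while the gradients land on `net`'s parameters. A minimal CPU sketch with a toy `nn.Linear` standing in for the real network:

```python
import io

import torch
from torch import nn
from torch.optim import SGD

torch.manual_seed(0)
net = nn.Linear(3, 1)
optim = SGD(net.parameters(), lr=0.1, momentum=0.9)

# Save as in save_model above: the state_dict plus the optimizer *object*.
buf = io.BytesIO()
torch.save({'epoch': 5, 'state_dict': net.state_dict(), 'optimizer': optim}, buf)
buf.seek(0)

# Resume: build a fresh network, copy the weights in, reuse the unpickled optimizer.
# Unpickling arbitrary objects requires weights_only=False in recent PyTorch versions.
ckpt = torch.load(buf, weights_only=False)
new_net = nn.Linear(3, 1)
new_net.load_state_dict(ckpt['state_dict'])
optimizer = ckpt['optimizer']

# The unpickled optimizer holds its own parameter tensors, not new_net's.
opt_params = optimizer.param_groups[0]['params']
shares_params = any(p is q for p in opt_params for q in new_net.parameters())
print(shares_params)  # False

before = new_net.weight.detach().clone()
loss = new_net(torch.randn(4, 3)).sum()
optimizer.zero_grad()
loss.backward()   # gradients are written to new_net's parameters ...
optimizer.step()  # ... but step() only updates the optimizer's own tensors (which have no grad)
print(torch.equal(before, new_net.weight))  # True: the network did not change
```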

Then, to resume training, I load the pre-trained model file and use it to resume from the training stage of, say, Epoch k. Before training actually starts, I check whether this is regular or resumed training:

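from collections import OrderedDict

if pretrained is not None:
    # the checkpoint contains a pickled optimizer object, which recent
    # PyTorch versions only load with weights_only=False
    state_dict = torch.load(pretrained, weights_only=False)
    # re-add the "module." prefix expected by the wrapped network
    new_state_dict = OrderedDict()
    for k, value in state_dict['state_dict'].items():
        key = "module.{}".format(k)
        new_state_dict[key] = value
    net.load_state_dict(new_state_dict)
    epoch = state_dict['epoch']
    print("pre-trained epoch number: {}".format(epoch))
    optimizer = state_dict['optimizer']
else:
    optimizer = SGD(
        net.parameters(),
        learning_rate,
        momentum=momentum,
        weight_decay=weight_decay)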
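In the resume branch above, the optimizer comes straight out of the checkpoint. The usual alternative is to construct the SGD optimizer on the loaded network's parameters in both branches and restore only its state via `optimizer.load_state_dict(...)`. A minimal sketch; the helper names `make_optimizer`, `save_model` and `resume` are hypothetical, and a toy `nn.Linear` stands in for the real network. If the network is wrapped in `nn.DataParallel`, calling `net.module.load_state_dict(...)` accepts the un-prefixed keys directly, so the `"module."` prefix does not need to be rebuilt by hand.

```python
import io

import torch
from torch import nn
from torch.optim import SGD


def make_optimizer(net, learning_rate=0.01, momentum=0.9, weight_decay=1e-4):
    return SGD(net.parameters(), learning_rate,
               momentum=momentum, weight_decay=weight_decay)


def save_model(net, optimizer, epoch, f):
    torch.save({'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()}, f)  # state only, not the object


def resume(net, optimizer, f):
    ckpt = torch.load(f, map_location='cpu')
    net.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])  # moves state to the params' device
    return ckpt['epoch']


torch.manual_seed(0)
net = nn.Linear(3, 1)
optimizer = make_optimizer(net)
net(torch.randn(4, 3)).sum().backward()
optimizer.step()  # creates the momentum buffers

buf = io.BytesIO()
save_model(net, optimizer, 5, buf)
buf.seek(0)

new_net = nn.Linear(3, 1)
new_optimizer = make_optimizer(new_net)  # built on the *new* parameters
epoch = resume(new_net, new_optimizer, buf)
print(epoch)  # 5
```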

For the learning rate, I use a step-wise decay: to train for 100 epochs, I decay the learning rate to 1/10 at Epoch 50, and by another factor of 10 at Epoch 80. In this example k=5, so the learning rate should be the same for the two runs being compared (baseline 1 and resumed training 2).
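If the step-wise decay is driven by `torch.optim.lr_scheduler.MultiStepLR` (an assumption; the post doesn't say how it is implemented), the scheduler has its own state that should go into the checkpoint alongside the optimizer. A minimal sketch with milestones at epochs 50 and 80, resuming at epoch 60 (past the first milestone); the checkpoint is kept in memory via `copy.deepcopy` for brevity, and the single backward pass per epoch is a placeholder for real training:

```python
import copy

import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR


def build():
    model = nn.Linear(4, 1)
    optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[50, 80], gamma=0.1)
    return model, optimizer, scheduler


def run_epochs(model, optimizer, scheduler, start, end, lrs):
    for epoch in range(start, end):
        lrs.append(optimizer.param_groups[0]['lr'])  # lr used during this epoch
        optimizer.zero_grad()
        model(torch.ones(1, 4)).sum().backward()  # placeholder for one epoch of training
        optimizer.step()
        scheduler.step()  # once per epoch


# Uninterrupted run.
model, opt, sched = build()
full = []
run_epochs(model, opt, sched, 0, 100, full)

# Interrupted run: 60 epochs, checkpoint, then fresh objects restored from it.
model, opt, sched = build()
resumed = []
run_epochs(model, opt, sched, 0, 60, resumed)
ckpt = copy.deepcopy({'epoch': 60,
                      'optimizer': opt.state_dict(),
                      'scheduler': sched.state_dict()})

model, opt, sched = build()
opt.load_state_dict(ckpt['optimizer'])   # restores the current (decayed) lr
sched.load_state_dict(ckpt['scheduler'])  # restores last_epoch and milestones
run_epochs(model, opt, sched, ckpt['epoch'], 100, resumed)
```

Only the learning-rate schedule is compared here; without `sched.load_state_dict`, the new scheduler would count epochs from zero again and apply the decays at the wrong time.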

The result: in baseline 1, I saw a gradually decreasing false negative rate, while in the second case I didn’t see such progress during resumed training. The loss values were also significantly different. I’m aware of the non-deterministic nature of GPU training, but that should not be the entire reason for such a discrepancy.
I ran both setups multiple times and got the same distinct results each time.

Did I miss anything important here?
Thank you!


I read through your explanation, but if you can reproduce this behaviour, even by modifying our MNIST example (https://github.com/pytorch/examples/blob/master/mnist/main.py), it would be much more helpful.

If I get a script that I can run and look through, I can debug it.

Did you ever solve this? I’m having a similar problem with an assignment due Saturday.
I trained for a week and my computer crashed, so I tried restoring my model (I saved the entire model, not just the state dict) on another computer so I could train the last little bit. I’ve ensured that all my optimizer values (using vanilla SGD) and learning rates are exactly the same; however, I immediately get very pronounced gradient explosion that prevents me from training any further. The only difference is that I’m not on the same GPU. I even tried moving my model to the CPU first and then calling .cuda() to move it back to the GPU.
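This won't by itself explain exploding gradients, but when a checkpoint comes from a different machine or GPU, a common pattern is to load it with `map_location='cpu'` (so tensors saved on any device are remapped to the CPU) and then move the rebuilt model to whatever device is available. A minimal sketch, assuming a small hypothetical `nn.Sequential` model and saving only its `state_dict()`:

```python
import io

import torch
from torch import nn


def make_model():
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.manual_seed(0)
model = make_model().to(device)
buf = io.BytesIO()
torch.save(model.state_dict(), buf)
buf.seek(0)

# map_location remaps every saved tensor onto the CPU, whatever device it was saved from;
# the restored model is then moved to the device available on this machine.
state = torch.load(buf, map_location='cpu')
restored = make_model()
restored.load_state_dict(state)
restored.to(device)

x = torch.randn(3, 8, device=device)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print(same)  # True
```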

@smth HYG: https://github.com/pytorch/pytorch/issues/4333#issuecomment-353890314

I have the same problem. Did you find any solution?

Instead of save_state_dict, shouldn’t it be torch.save(model.state_dict(), 'mytraining.pt')? (The former didn’t work for me; I got errors saying that no such method exists.)


Yes, you are right. Your approach works fine. The recommended way of saving and loading is described here.
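In short, the state_dict approach saves only the tensors and re-creates the architecture in code when loading. A minimal sketch; `Net` and the temporary path are hypothetical placeholders:

```python
import os
import tempfile

import torch
from torch import nn


class Net(nn.Module):  # hypothetical model; any nn.Module works the same way
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = Net()
path = os.path.join(tempfile.mkdtemp(), 'mytraining.pt')
torch.save(model.state_dict(), path)  # parameters and buffers only

loaded = Net()                            # re-create the architecture in code ...
loaded.load_state_dict(torch.load(path))  # ... then copy the saved tensors into it
loaded.eval()                             # for inference; use loaded.train() to keep training
```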


Why is a deep copy needed?
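The post this question responds to isn't included here. Assuming it refers to the common pattern of keeping the best weights in memory, e.g. `best = copy.deepcopy(model.state_dict())`: `state_dict()` returns tensors that share memory with the live parameters, so without a copy the "saved" weights change as training continues. A minimal sketch:

```python
import copy

import torch
from torch import nn

model = nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

snapshot = model.state_dict()             # tensors share memory with the parameters
best = copy.deepcopy(model.state_dict())  # independent copy

model(torch.ones(1, 2)).sum().backward()
optimizer.step()                          # updates the parameters in place

print(torch.equal(snapshot['weight'], model.weight))  # True: the "snapshot" moved too
print(torch.equal(best['weight'], model.weight))      # False: the deep copy kept the old values
```

Writing the state_dict to disk with `torch.save` serializes the current values, so the copy matters only for in-memory snapshots.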

Thanks @rasbt, torch.save(model.state_dict(), 'saved_model_state.pt') worked for me.


Huh, I’m confused: is torch.save(model, ...) actually wrong, and should we be using torch.save(model.state_dict(), ...) instead?

I’m trying to convert a PyTorch machine learning model (.pth) into Caffe or ONNX, but I’m confused about how to use the .pth files properly. I have a text encoder and an image encoder saved as .pth files, trained to synthesize images from text. I tried passing text_encoder.state_dict() directly, but I keep getting errors:

# Export the model
torch_out_image = torch.onnx._export(text_encoder.state_dict(), x, "super_resolution.onnx", verbose=True)

Does anyone know the proper way to pass a text encoder to the ONNX converter?

Do you mean that after reloading the model to continue training I also have to call .eval() once? Or did you mean this only for evaluation/inference? It’s not really clear to me.
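The mode matters because layers such as dropout (and batch norm) behave differently in training and evaluation. `model.eval()` is meant for validation/inference; before continuing to train, the model should be back in training mode via `model.train()`. A minimal sketch with a toy dropout model:

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5))
x = torch.randn(64, 16)

model.eval()   # dropout disabled: deterministic outputs for validation/inference
with torch.no_grad():
    a, b = model(x), model(x)
print(torch.equal(a, b))  # True

model.train()  # dropout enabled again before continuing training
with torch.no_grad():
    c, d = model(x), model(x)
print(torch.equal(c, d))  # False: each forward pass draws a new random dropout mask
```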

This is how I do it:

torch.save(net.state_dict(),model_save_path + '_.pth')

save_checkpoint({
    'epoch': epoch + 1,
    # 'arch': args.arch,
    'state_dict': net.state_dict(),
    'optimizer': optimizer.state_dict(),
}, is_best, mPath, str(val_acc) + '_' + str(val_los) + '_' + str(epoch) + '_checkpoint.pth.tar')

Where:

def save_checkpoint(state, is_best, save_path, filename):
    filename = os.path.join(save_path, filename)
    torch.save(state, filename)
    if is_best:
        bestname = os.path.join(save_path, 'model_best.pth.tar')
        shutil.copyfile(filename, bestname)
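Because the dict stores `epoch + 1`, the value read back when resuming is already the first epoch that still needs to run. A minimal sketch of the resume side, with `save_checkpoint` repeated from above so the block runs on its own; the toy model, file names and the `is_best` rule are placeholders:

```python
import os
import shutil
import tempfile

import torch
from torch import nn


def save_checkpoint(state, is_best, save_path, filename):
    filename = os.path.join(save_path, filename)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(save_path, 'model_best.pth.tar'))


save_path = tempfile.mkdtemp()
net = nn.Linear(3, 1)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
num_epochs = 10

for epoch in range(3):  # pretend training is interrupted after epoch 2
    # ... train and validate one epoch here ...
    save_checkpoint({'epoch': epoch + 1,
                     'state_dict': net.state_dict(),
                     'optimizer': optimizer.state_dict()},
                    is_best=(epoch == 1), save_path=save_path,
                    filename='epoch_{}_checkpoint.pth.tar'.format(epoch))

# Resume from the latest checkpoint: 'epoch' already points at the next epoch to run.
ckpt = torch.load(os.path.join(save_path, 'epoch_2_checkpoint.pth.tar'))
net.load_state_dict(ckpt['state_dict'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch']
remaining = list(range(start_epoch, num_epochs))
print(remaining[0])  # 3
```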

huh I’m confused, is torch.save(model, …) actually wrong and should we be using torch.save(model.state_dict(), …) instead?

No, it’s not wrong, just a different approach: with the former, the whole object gets pickled, and with the latter, only its parameters (and buffers) get pickled. Since pickle can be quite a mess when it comes to import dependencies, I would generally recommend the latter approach, especially if you are planning to run the model on a different machine.
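The two approaches side by side, as a minimal sketch with a small `nn.Sequential` model. Saving the whole object stores a reference to the module's class, so that class must be importable under the same path when loading; recent PyTorch releases also default `torch.load` to `weights_only=True`, which refuses such pickled objects unless `weights_only=False` is passed (only do that for trusted files).

```python
import io

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 2), nn.ReLU())

# (a) Whole object: pickles the module, including references to its classes.
buf_a = io.BytesIO()
torch.save(model, buf_a)
buf_a.seek(0)
model_a = torch.load(buf_a, weights_only=False)

# (b) State dict: only tensors are saved; the architecture is rebuilt in code.
buf_b = io.BytesIO()
torch.save(model.state_dict(), buf_b)
buf_b.seek(0)
model_b = nn.Sequential(nn.Linear(4, 2), nn.ReLU())
model_b.load_state_dict(torch.load(buf_b))

x = torch.randn(3, 4)
with torch.no_grad():
    ok = torch.equal(model(x), model_a(x)) and torch.equal(model(x), model_b(x))
print(ok)  # True
```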


I find that my model accuracy drops a little when I load a saved checkpoint, compared to before I saved the state.

Here is the dict I’m saving using torch.save(…):

save_checkpoint({
    'epoch': cur_epoch,
    'state_dict': model.state_dict(),
    'best_prec': best_prec,
    'loss_train': loss_train,
    'optimizer': optimizer.state_dict(),
}, is_best, OUT_DIR, 'acc-{:.4f}_loss-{:.4f}_epoch-{}_checkpoint.pth.tar'.format(val_acc, val_loss, cur_epoch))

And here is how I load a saved state:

import os

import torch


def load_checkpoint(load_path, model, optimizer):
    """ loads state into model and optimizer and returns:
        epoch, best_precision, loss_train[]
    """
    # the path parameter is named load_path to match its uses below
    if os.path.isfile(load_path):
        print("=> loading checkpoint '{}'".format(load_path))
        checkpoint = torch.load(load_path)
        epoch = checkpoint['epoch']
        best_prec = checkpoint['best_prec']
        loss_train = checkpoint['loss_train']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(load_path, epoch))
        return epoch, best_prec, loss_train
    else:
        print("=> no checkpoint found at '{}'".format(load_path))
        # epoch, best_precision, loss_train
        return 1, 0, []

Can anybody spot what i am doing wrong?
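One generic way to narrow this down (a diagnostic sketch, not a diagnosis of the code above): evaluate the same batch before saving and after loading, both in `eval()` mode under `torch.no_grad()`. If the outputs match exactly, the checkpoint round trip is fine, and the drop likely comes from elsewhere, e.g. comparing a train-mode metric with an eval-mode one. The toy model below, with BatchNorm and Dropout, is a hypothetical stand-in:

```python
import io

import torch
from torch import nn


def make_model():
    return nn.Sequential(nn.Linear(10, 32), nn.BatchNorm1d(32), nn.ReLU(),
                         nn.Dropout(0.2), nn.Linear(32, 3))


torch.manual_seed(0)
model = make_model()
model.train()
for _ in range(5):  # a few train-mode passes update the BatchNorm running statistics
    model(torch.randn(16, 10))

x = torch.randn(8, 10)
model.eval()
with torch.no_grad():
    before = model(x)

buf = io.BytesIO()
torch.save(model.state_dict(), buf)  # includes BatchNorm running_mean / running_var
buf.seek(0)

restored = make_model()
restored.load_state_dict(torch.load(buf))
restored.eval()  # a freshly built module starts in train mode
with torch.no_grad():
    after = restored(x)

print(torch.equal(before, after))  # True
```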


Should I care about the model’s mode (train/eval) when saving it via the state_dict method?
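As far as what gets saved, the mode makes no difference: the train/eval flag is a plain Python attribute, not a tensor, so it is not part of the state_dict, while BatchNorm running statistics are buffers and are saved in either mode. What matters is setting the desired mode after loading. A minimal sketch:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model.eval()
sd = model.state_dict()
print(list(sd.keys()))
# ['0.weight', '0.bias', '1.weight', '1.bias', '1.running_mean', '1.running_var', '1.num_batches_tracked']

fresh = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
fresh.load_state_dict(sd)
print(fresh.training)  # True: the eval flag of the saved model was not carried over
```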

Are you able to reproduce the results this way? Don’t we need to store the optimizer states as well?

Don’t we need to store the optimizer states as well?
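For optimizers with per-parameter state, such as SGD with momentum, that state lives in the optimizer rather than in `model.state_dict()`, so restarting without it resets the momentum. A minimal sketch showing what `optimizer.state_dict()` contains after one step:

```python
import torch
from torch import nn

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

model(torch.randn(4, 3)).sum().backward()
optimizer.step()  # the first step creates one momentum buffer per parameter

state = optimizer.state_dict()
print(sorted(state.keys()))                # ['param_groups', 'state']
print(state['param_groups'][0]['params'])  # [0, 1]: parameters are referenced by index
print('momentum_buffer' in state['state'][0])  # True
```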