How do I save a net in PyTorch?

I think it's torch.save(aNet, 'myNet'), and that doesn't complain, but I can't seem to load it… :-/
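torch.save(aNet, ...) pickles the whole module object. Loading it later therefore needs the same class definition importable, and on recent PyTorch releases, where torch.load defaults to weights_only=True, it also needs weights_only=False. The approach usually recommended is to save only the state_dict and load it into a freshly constructed model. Here is a minimal sketch. MyNet and the file name are hypothetical stand-ins for your own network and path:

```python
import os
import tempfile

import torch
import torch.nn as nn


class MyNet(nn.Module):
    """Hypothetical stand-in for the poster's network."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


a_net = MyNet()
path = os.path.join(tempfile.mkdtemp(), "myNet.pt")  # hypothetical file name

# Save only the parameters and buffers, not the pickled module object.
torch.save(a_net.state_dict(), path)

# To load, construct a fresh instance of the same class, then copy the weights in.
restored = MyNet()
restored.load_state_dict(torch.load(path))
restored.eval()  # switch to inference mode if you are not resuming training
```

Because only tensors are stored, the file does not depend on where MyNet was defined when it was saved. You just need to be able to build the same architecture again before calling load_state_dict.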

Code from the ImageNet example:

Loading:

import os
import torch

# Excerpt: `args`, `model` and `best_prec1` are defined elsewhere in the script.
if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
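The same resume logic can be wrapped in a small function that does not depend on argparse. Here is a sketch, where load_checkpoint, its (0, 0.0) fallback and the file name are hypothetical choices rather than part of the ImageNet example. It passes map_location="cpu" to torch.load so that a checkpoint written on a GPU also loads on a CPU-only machine:

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_checkpoint(path, model):
    """Hypothetical helper mirroring the resume logic above.

    Returns (start_epoch, best_prec1); falls back to (0, 0.0) when no file exists.
    """
    if not os.path.isfile(path):
        print("=> no checkpoint found at '{}'".format(path))
        return 0, 0.0
    print("=> loading checkpoint '{}'".format(path))
    # map_location="cpu" remaps any GPU tensors in the file onto the CPU.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(path, checkpoint['epoch']))
    return checkpoint['epoch'], checkpoint['best_prec1']


# Usage: write a tiny checkpoint with made-up values, then resume from it.
model = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.tar")
torch.save({
    'epoch': 3,
    'arch': 'linear',
    'state_dict': model.state_dict(),
    'best_prec1': 71.5,
}, path)

fresh = nn.Linear(4, 2)
start_epoch, best_prec1 = load_checkpoint(path, fresh)
```

If you train on a GPU afterwards, move the model with model.cuda() (or .to(device)) after loading the weights.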

Saving:

# Store epoch + 1 so that resuming starts at the next epoch.
torch.save({
    'epoch': epoch + 1,
    'arch': args.arch,
    'state_dict': model.state_dict(),
    'best_prec1': best_prec1,
}, 'checkpoint.tar')
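The ImageNet example wraps this call in a save_checkpoint helper that also keeps a copy of the best-scoring checkpoint. Below is a sketch along those lines. The directory argument, the file names and the accuracy values in the loop are illustrative assumptions, not real training results:

```python
import os
import shutil
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(state, is_best, directory, filename="checkpoint.tar"):
    """Write `state` to directory/filename; if is_best, also copy it to model_best.tar."""
    path = os.path.join(directory, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(directory, "model_best.tar"))
    return path


model = nn.Linear(4, 2)
out_dir = tempfile.mkdtemp()
best_prec1 = 0.0
for epoch, prec1 in enumerate([60.0, 72.5, 70.0]):  # made-up accuracies for illustration
    is_best = prec1 > best_prec1
    best_prec1 = max(prec1, best_prec1)
    save_checkpoint({
        'epoch': epoch + 1,
        'arch': 'linear',
        'state_dict': model.state_dict(),
        'best_prec1': best_prec1,
    }, is_best, out_dir)
```

After the loop, checkpoint.tar holds the latest epoch for resuming, while model_best.tar still holds the epoch with the highest accuracy.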

Does loading this checkpoint also restore the most recently updated learning rate?

Nope. The learning rate is stored in the optimizer's param_groups, not in the model's state_dict, so you need to save and restore optimizer.state_dict() too.

Add this entry to the dict you save:

'optim_state': optimizer.state_dict(),

and restore it like this:

optimizer.load_state_dict(checkpoint['optim_state'])
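Here is a self-contained round trip showing that the decayed learning rate and SGD's momentum buffers come back after loading. The tiny model, the manual 10x decay and the file name are illustrative assumptions:

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so that SGD creates its momentum buffers.
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Decay the learning rate by hand, as a step schedule would.
for group in optimizer.param_groups:
    group['lr'] *= 0.1

path = os.path.join(tempfile.mkdtemp(), "checkpoint.tar")
torch.save({
    'epoch': 1,
    'state_dict': model.state_dict(),
    'optim_state': optimizer.state_dict(),
}, path)

# Resume: rebuild the model and optimizer with the original settings, then load.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
checkpoint = torch.load(path)
new_model.load_state_dict(checkpoint['state_dict'])
new_optimizer.load_state_dict(checkpoint['optim_state'])
print(new_optimizer.param_groups[0]['lr'])  # the decayed lr, not the 0.1 passed above
```

If you use a scheduler from torch.optim.lr_scheduler, note that it also has its own state_dict() and load_state_dict(). Save and restore it alongside the optimizer in the same way.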