Why is it not recommended to save the optimizer, model, etc. as picklable/dill-able objects in PyTorch, and instead to save their state dicts and load them?

Why is it recommended to save the state dicts and load them, instead of saving the objects themselves (with dill, for example) and getting usable objects back immediately on load?

I've done the latter without many issues, and it saves the user from writing code to rebuild the objects.

Instead, the recommended approach is something like:

def _load_model_and_optimizer_from_checkpoint(args: Namespace, training: bool = True) -> Namespace:
    """
    based from: https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
    """
    import torch
    from torch import optim
    import torch.nn as nn
    # model = Net()
    args.model = nn.Linear()
    # optimizer = optim.SGD(args.model.parameters(), lr=0.001, momentum=0.9)
    optimizer = optim.Adam(args.model.parameters(), lr=0.001)

    # scheduler...

    checkpoint = torch.load(args.PATH)
    args.model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    args.epoch_num = checkpoint['epoch_num']
    args.loss = checkpoint['loss']

    args.model.train() if training else args.model.eval()

For example, I saved a checkpoint like this:

def save_for_meta_learning(args: Namespace, ckpt_filename: str = 'ckpt.pt'):
    """
    Write a resumable checkpoint to ``args.log_root / ckpt_filename`` (lead worker only).

    Non-lead workers return immediately. The lead worker first saves the logger's current
    plots/stats and then writes a single ``torch.save`` file, pickled with ``dill``. The file
    holds both whole pickled objects (``meta_learner``, ``f``, ``outer_opt``, ``args``) and
    their state dicts / ``str()`` summaries.
    """
    if not is_lead_worker(args.rank):
        return
    import dill

    args.logger.save_current_plots_and_stats()
    # - ckpt
    assert uutils.xor(args.training_mode == 'epochs', args.training_mode == 'iterations')
    # presumably unwraps a DDP wrapper so the bare module is stored — see get_model_from_ddp
    f: nn.Module = get_model_from_ddp(args.base_model)
    # pickle vs torch.save https://discuss.pytorch.org/t/advantages-disadvantages-of-using-pickle-module-to-save-models-vs-torch-save/79016
    args_pickable: Namespace = uutils.make_args_pickable(args)

    checkpoint: dict = {
        # training progress; training_mode says which counter ('it' or 'epoch_num') is meaningful
        'training_mode': args.training_mode,
        'it': args.it,
        'epoch_num': args.epoch_num,

        # run configuration (some older checkpoints do not contain 'args')
        'args': args_pickable,

        # whole pickled meta-learner + a human-readable summary of it
        'meta_learner': args.meta_learner,
        'meta_learner_str': str(args.meta_learner),

        # base model: pickled object, its state dict, and a readable architecture summary.
        # The state dict is the part that stays loadable if the model's code is refactored.
        'f': f,
        'f_state_dict': f.state_dict(),
        'f_str': str(f),

        # outer optimizer: pickled object, its state dict, and a summary showing which optimizer was used
        'outer_opt': args.outer_opt,
        'outer_opt_state_dict': args.outer_opt.state_dict(),
        'outer_opt_str': str(args.outer_opt),
    }
    ckpt_path = args.log_root / ckpt_filename
    torch.save(checkpoint, ckpt_path, pickle_module=dill)

and then loaded it like this:

def get_model_opt_meta_learner_to_resume_checkpoint_resnets_rfs(args: Namespace,
                                                                path2ckpt: str,
                                                                filename: str,
                                                                device: Optional[torch.device] = None
                                                                ) -> tuple[nn.Module, optim.Optimizer, MetaLearner]:
    """
    Get the model, optimizer, meta_learner to resume training from checkpoint.

    The checkpoint contains whole pickled objects (e.g. ``'meta_learner'``). It is presumably
    written with ``torch.save(..., pickle_module=dill)``, so it is read back with ``dill`` and
    with ``weights_only=False`` where ``torch.load`` supports that flag. Torch >= 2.6 defaults
    to ``weights_only=True``, which rejects arbitrary pickled objects.

    Side effects: sets ``args.epoch_num`` or ``args.it`` from the checkpoint, depending on its
    ``training_mode``.

    Args:
        args: run namespace; ``args.outer_lr`` is used when the checkpoint predates
            ``'outer_opt_str'``.
        path2ckpt: directory containing the checkpoint (``~`` is expanded).
        filename: checkpoint file name inside ``path2ckpt``.
        device: if given, the base model is moved there BEFORE the optimizer state is
            restored, so the optimizer state ends up on the same device as the parameters.

    Returns:
        ``(model, outer_opt, meta_learner)``, where ``model is meta_learner.base_model``.

    Examples:
        - see: _resume_from_checkpoint_meta_learning_for_resnets_rfs_test
    """
    import inspect

    import dill
    import uutils

    path2ckpt: Path = Path(path2ckpt).expanduser()
    # Full unpickling is required because whole objects were saved.
    # NOTE: unpickling can execute arbitrary code -- only load checkpoints you trust.
    load_kwargs: dict = {'map_location': torch.device('cpu'), 'pickle_module': dill}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    ckpt: dict = torch.load(path2ckpt / filename, **load_kwargs)
    # args_ckpt: Namespace = ckpt['args']
    training_mode = ckpt.get('training_mode')
    if training_mode is not None:
        assert uutils.xor(training_mode == 'epochs', training_mode == 'iterations')
        if training_mode == 'epochs':
            args.epoch_num = ckpt['epoch_num']
        else:
            args.it = ckpt['it']
    # - get meta-learner
    meta_learner: MetaLearner = ckpt['meta_learner']
    # - device setup. This must happen before the optimizer is built and its state dict loaded:
    # Optimizer.load_state_dict casts the saved state (e.g. Adam's moment estimates) to the
    # device of the parameters it is attached to. Moving the model afterwards would leave that
    # state on CPU while the parameters sit on `device`.
    if device is not None:
        meta_learner.base_model = meta_learner.base_model.to(device)
    # - get model
    model: nn.Module = meta_learner.base_model
    # - get outer-opt
    outer_opt_str = ckpt.get('outer_opt_str')
    if outer_opt_str is not None:
        # use the string to create the optimizer, then restore its saved state.
        # NOTE(review): get_optimizer only receives the string; confirm it binds the
        # optimizer to `model`'s parameters.
        outer_opt: optim.Optimizer = get_optimizer(outer_opt_str)
        outer_opt_state_dict: dict = ckpt['outer_opt_state_dict']
        outer_opt.load_state_dict(outer_opt_state_dict)
    else:
        # Older checkpoint without an optimizer description: start a fresh Adam. This is not
        # ideal, but since Adam uses exponential moving averages for its adaptive learning
        # rate, hopefully this doesn't hurt the resumed run too much.
        outer_opt: optim.Optimizer = optim.Adam(model.parameters(), lr=args.outer_lr)
    return model, outer_opt, meta_learner

without issues.


Related:

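I think the main reason is this (see the Stack Overflow question "How to open a pickled file with dill where the objects are not findable anymore?"). Pickle/dill store objects by reference to their module and class path. If you save objects and later refactor the code (for example, move or rename a class or module), unpickling can fail because it no longer knows where the code is, which can make the whole checkpoint unusable. State dicts contain only tensors and plain Python values, so they keep loading after such refactors.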