Loss becomes higher after resuming from DDP model checkpoint

Problem

Recently I tried to use DDP; I just followed the tutorial to convert my original DP code.
I accidentally stopped my DDP training, so I planned to resume the model as I did before. However, the loss becomes much higher after resuming, so I think there might be an error in my code.

Here is my code.

save model

import shutil

import torch

# code for saving a checkpoint after validation (rank 0 only)
if args.local_rank == 0:
    tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

    output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
    print(output_best)
    log_training.write(output_best + '\n')
    log_training.flush()

    save_checkpoint({
        'epoch': epoch + 1,
        'arch': args.arch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'best_prec1': best_prec1,
    }, is_best)
def save_checkpoint(state, is_best):
    # overwrite the latest checkpoint; keep a separate copy of the best one
    filename = '%s/%s/ckpt.pth.tar' % (args.root_model, args.store_name)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, filename.replace('pth.tar', 'best.pth.tar'))

I don’t save the scheduler because I set the learning rate for the optimizer manually, and I set it before training starts:

adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)
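For reference, adjust_learning_rate just writes the computed learning rate into the optimizer’s param groups. A rough sketch of it (the step-decay factor below is illustrative, not the exact schedule):

def adjust_learning_rate(optimizer, epoch, lr_type, lr_steps):
    # decay the base LR by 10x at every milestone the epoch has passed
    if lr_type == 'step':
        decay = 0.1 ** sum(1 for step in lr_steps if epoch >= step)
        lr = args.lr * decay
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr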

resume model

import os
from collections import OrderedDict

import torch

if args.resume:
    if os.path.isfile(args.resume):
        # load onto the CPU first; weights are moved to the GPU later
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        pretrained_dict = checkpoint['state_dict']
        new_state_dict = OrderedDict()
        for k, v in pretrained_dict.items():
            if '.total' not in k:
                name = k[7:]  # strip the 'module.' prefix added by DDP/DP
                # name = name.replace('.net', '')
                new_state_dict[name] = v
        model.load_state_dict(new_state_dict)

        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[args.local_rank],
                                                          output_device=args.local_rank)

        if 'epoch' in checkpoint.keys():
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            
            # get_optim_policies is a function that sets the optimizer policies
            optimizer = torch.optim.SGD(get_optim_policies(model), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

            optimizer.load_state_dict(checkpoint['optimizer'])
            print(("=> loaded checkpoint '{}' (epoch {})"
                .format(args.resume, checkpoint['epoch'])))
            print(("=> best top1 '{}'".format(best_prec1)))

The above code works well when I use DP. Is there anything I’m missing when using DDP?

When you call save_checkpoint, is the model var a DDP instance? If so, you might need to save model.module instead (see the sketch below). But I don’t think that’s the reason for the loss jump. When you load the checkpoint without DDP or DP (just a local model), is the loss after recovery as expected?
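For example, a sketch of the first suggestion, reusing save_checkpoint from the question (the hasattr check makes it work with and without a wrapper):

# unwrap DDP/DP before saving so the checkpoint keys carry no 'module.' prefix
model_to_save = model.module if hasattr(model, 'module') else model
save_checkpoint({
    'epoch': epoch + 1,
    'arch': args.arch,
    'state_dict': model_to_save.state_dict(),
    'optimizer': optimizer.state_dict(),
    'best_prec1': best_prec1,
}, is_best)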

I might be missing something, but it looks like in the “resume model” part the model state is loaded onto the CPU and never moved to local_rank before being passed to the DDP constructor. Or is the model already on the correct device before loading the state dict?
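Something along these lines, based on the snippet in the question (assuming args.local_rank is this process’s GPU index):

# load the weights on CPU, move the model to this rank's GPU, then wrap in DDP
model.load_state_dict(new_state_dict)
torch.cuda.set_device(args.local_rank)
model = model.cuda(args.local_rank)
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[args.local_rank], output_device=args.local_rank)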