Load from a checkpoint?

Hello, I’m trying to run some code from GitHub.
Here is the bash command:

%cd /content/drive/'My Drive'/WS_DAN_PyTorch-master
!python train_bap.py train\
    --model-name inception \
    --batch-size 12 \
    --dataset car \
    --image-size 224 \
    --input-size 224 \
    --checkpoint-path /content/drive/'My Drive'/WS_DAN_PyTorch-master/checkpoint/car \
    --optim sgd \
    --scheduler step \
    --lr 0.001 \
    --momentum 0.9 \
    --weight-decay 1e-5 \
    --workers 4 \
    --parts 32 \
    --epochs 80 \
    --use-gpu \
    --multi-gpu \
    --gpu-ids 0

As you can see, I have defined the checkpoint path. Yesterday I trained the model for one epoch, so the best model is now stored in the specified folder. When I checked the train function, I couldn’t find a load call (net.load_state_dict).
In the test function, however, the call is there.
Where should I put it so that I don’t mess up the rest of the code? At the beginning?
Thanks for your help.


def train():
    # input params
    set_seed(GLOBAL_SEED)
    config = getConfig()
    data_config = getDatasetConfig(config.dataset)
    sw_log = 'logs/%s' % config.dataset
    sw = SummaryWriter(log_dir=sw_log)
    best_prec1 = 0.
    rate = 0.875

    # define train_dataset and loader
    transform_train = transforms.Compose([
        transforms.Resize((int(config.input_size//rate), int(config.input_size//rate))),
        transforms.RandomCrop((config.input_size,config.input_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=32./255.,saturation=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    train_dataset = CustomDataset(
        data_config['train'], data_config['train_root'], transform=transform_train)
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.workers, pin_memory=True, worker_init_fn=_init_fn)

    transform_test = transforms.Compose([
        transforms.Resize((config.image_size, config.image_size)),
        transforms.CenterCrop(config.input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    val_dataset = CustomDataset(
        data_config['val'], data_config['val_root'], transform=transform_test)
    val_loader = DataLoader(
        val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.workers, pin_memory=True, worker_init_fn=_init_fn)
    # logging dataset info
    print('Dataset Name:{dataset_name}, Train:[{train_num}], Val:[{val_num}]'.format(
        dataset_name=config.dataset,
        train_num=len(train_dataset),
        val_num=len(val_dataset)))
    print('Batch Size:[{0}], Total:::Train Batches:[{1}],Val Batches:[{2}]'.format(
        config.batch_size, len(train_loader), len(val_loader)
    ))
    # define model
    if config.model_name == 'inception':
        net = inception_v3_bap(pretrained=True, aux_logits=False,num_parts=config.parts)
    elif config.model_name == 'resnet50':
        net = resnet50(pretrained=True,use_bap=True)

    
    in_features = net.fc_new.in_features
    new_linear = torch.nn.Linear(
        in_features=in_features, out_features=train_dataset.num_classes)
    net.fc_new = new_linear
    # feature center
    feature_len = 768 if config.model_name == 'inception' else 512
    center_dict = {'center': torch.zeros(
        train_dataset.num_classes, feature_len*config.parts)}

    # gpu config
    use_gpu = torch.cuda.is_available() and config.use_gpu
    if use_gpu:
        net = net.cuda()
        center_dict['center'] = center_dict['center'].cuda()
    gpu_ids = [int(r) for r in config.gpu_ids.split(',')]
    if use_gpu and config.multi_gpu:
        net = torch.nn.DataParallel(net, device_ids=gpu_ids)

    # define optimizer
    assert config.optim in ['sgd', 'adam'], 'optim name not found!'
    if config.optim == 'sgd':
        optimizer = torch.optim.SGD(
            net.parameters(), lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay)
    elif config.optim == 'adam':
        optimizer = torch.optim.Adam(
            net.parameters(), lr=config.lr, weight_decay=config.weight_decay)

    # define learning scheduler
    assert config.scheduler in ['plateau',
                                'step'], 'scheduler not supported!!!'
    if config.scheduler == 'plateau':
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', patience=3, factor=0.1)
    elif config.scheduler == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=2, gamma=0.9)

    # define loss
    criterion = torch.nn.CrossEntropyLoss()
    if use_gpu:
        criterion = criterion.cuda()

    # train val parameters dict
    state = {'model': net, 'train_loader': train_loader,
             'val_loader': val_loader, 'criterion': criterion,
             'center': center_dict['center'], 'config': config,
             'optimizer': optimizer}
    ## train and val
    engine = Engine()
    print(config)
    for e in range(config.epochs):
        if config.scheduler == 'step':
            scheduler.step()
        lr_val = get_lr(optimizer)
        print("Start epoch %d ==========,lr=%f" % (e, lr_val))
        train_prec, train_loss = engine.train(state, e)
        prec1, val_loss = engine.validate(state)
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': e + 1,
            'state_dict': net.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
            'center': center_dict['center']
        }, is_best, config.checkpoint_path)
        sw.add_scalars("Accurancy", {'train': train_prec, 'val': prec1}, e)
        sw.add_scalars("Loss", {'train': train_loss, 'val': val_loss}, e)
        if config.scheduler == 'plateau':
            scheduler.step(val_loss)

def test():
    ##
    engine = Engine()
    config = getConfig()
    data_config = getDatasetConfig(config.dataset)
    # define dataset
    transform_test = transforms.Compose([
        transforms.Resize((config.image_size, config.image_size)),
        transforms.CenterCrop(config.input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    val_dataset = CustomDataset(
        data_config['val'], data_config['val_root'], transform=transform_test)
    val_loader = DataLoader(
        val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.workers, pin_memory=True)
    # define model
    if config.model_name == 'inception':
        net = inception_v3_bap(pretrained=True, aux_logits=False)
    elif config.model_name == 'resnet50':
        net = resnet50(pretrained=True)

    in_features = net.fc_new.in_features
    new_linear = torch.nn.Linear(
        in_features=in_features, out_features=val_dataset.num_classes)
    net.fc_new = new_linear

    # load checkpoint
    use_gpu = torch.cuda.is_available() and config.use_gpu
    if use_gpu:
        net = net.cuda()
    gpu_ids = [int(r) for r in config.gpu_ids.split(',')]
    if use_gpu and len(gpu_ids) > 1:
        net = torch.nn.DataParallel(net, device_ids=gpu_ids)
    #checkpoint_path = os.path.join(config.checkpoint_path,'model_best.pth.tar')
    net.load_state_dict(torch.load(config.checkpoint_path)['state_dict'])

    # define loss
    criterion = torch.nn.CrossEntropyLoss()
    if use_gpu:
        criterion = criterion.cuda()
    prec1, prec5 = engine.test(val_loader, net, criterion)

Here is the engine function :

class Engine():
    def __init__(self,):
        pass

    def train(self,state,epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        config = state['config']
        print_freq = config.print_freq
        model = state['model']
        criterion = state['criterion']
        optimizer = state['optimizer']
        train_loader = state['train_loader']
        model.train()
        end = time.time()
        for i, (img, label) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            target = label.cuda()
            input = img.cuda()
            # compute output
            attention_maps, raw_features, output1 = model(input)
            features = raw_features.reshape(raw_features.shape[0], -1)

            feature_center_loss, center_diff = calculate_pooling_center_loss(
                features, state['center'], target, alfa=config.alpha)

            # update model.centers
            state['center'][target] += center_diff

            # compute refined loss
            # img_drop = attention_drop(attention_maps,input)
            # img_crop = attention_crop(attention_maps, input)
            img_crop, img_drop = attention_crop_drop(attention_maps, input)
            _, _, output2 = model(img_drop)
            _, _, output3 = model(img_crop)

            loss1 = criterion(output1, target)
            loss2 = criterion(output2, target)
            loss3 = criterion(output3, target)

            loss = (loss1+loss2+loss3)/3 + feature_center_loss
            # measure accuracy and record loss
            prec1, prec5 = accuracy(output1, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        epoch, i, len(train_loader), batch_time=batch_time,
                        data_time=data_time, loss=losses, top1=top1, top5=top5))
                print("loss1,loss2,loss3,feature_center_loss", loss1.item(), loss2.item(), loss3.item(),
                    feature_center_loss.item())
        return top1.avg, losses.avg
    
    def validate(self,state):
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        
        config = state['config']
        print_freq = config.print_freq
        model = state['model']
        val_loader = state['val_loader']
        criterion = state['criterion']
        # switch to evaluate mode
        model.eval()
        with torch.no_grad():
            end = time.time()
            for i, (input, target) in enumerate(val_loader):
                target = target.cuda()
                input = input.cuda()
                # forward
                attention_maps, raw_features, output1 = model(input)
                features = raw_features.reshape(raw_features.shape[0], -1)
                feature_center_loss, _ = calculate_pooling_center_loss(
                    features, state['center'], target, alfa=config.alpha)

                img_crop, img_drop = attention_crop_drop(attention_maps, input)
                # img_drop = attention_drop(attention_maps,input)
                # img_crop = attention_crop(attention_maps,input)
                _, _, output2 = model(img_drop)
                _, _, output3 = model(img_crop)
                loss1 = criterion(output1, target)
                loss2 = criterion(output2, target)
                loss3 = criterion(output3, target)
                # loss = loss1 + feature_center_loss
                loss = (loss1+loss2+loss3)/3+feature_center_loss
                # measure accuracy and record loss
                prec1, prec5 = accuracy(output1, target, topk=(1, 5))
                losses.update(loss.item(), input.size(0))
                top1.update(prec1[0], input.size(0))
                top5.update(prec5[0], input.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                            i, len(val_loader), batch_time=batch_time, loss=losses,
                            top1=top1, top5=top5))

            print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
                .format(top1=top1, top5=top5))

        return top1.avg, losses.avg

I tried to place it here:

# train val parameters dict
    net.load_state_dict(torch.load("/content/drive/My Drive/WS_DAN_PyTorch-master/checkpoint/car/checkpoint.pth.tar"))
    state = {'model': net, 'train_loader': train_loader,
             'val_loader': val_loader, 'criterion': criterion,
             'center': center_dict['center'], 'config': config,
             'optimizer': optimizer}

but it gave a runtime error about missing weights (missing keys in the state dict). I tried adding strict=False, which removed the error, but then training started from the beginning, so the checkpoint was effectively ignored!


Yesterday it was at 14%, and each epoch takes about 6 hours to run, so I really don’t want to start from scratch.
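
I think the reason the first attempt failed is that save_checkpoint writes a dict with the keys 'epoch', 'state_dict', 'best_prec1', 'optimizer' and 'center', not the raw weights, so load_state_dict sees no matching parameter names, and with strict=False it simply skips everything. A quick way to confirm what is in the file (just a small check, using the same path as above):

import torch

ckpt = torch.load("/content/drive/My Drive/WS_DAN_PyTorch-master/checkpoint/car/checkpoint.pth.tar",
                  map_location='cpu')
print(ckpt.keys())   # should show: epoch, state_dict, best_prec1, optimizer, center
print(ckpt['epoch'], ckpt['best_prec1'])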

Loading the nested 'state_dict' from the checkpoint file instead seems to work:

net.load_state_dict(torch.load(config.checkpoint_path+'/checkpoint.pth.tar')['state_dict'])
state = {'model': net, 'train_loader': train_loader,
         'val_loader': val_loader, 'criterion': criterion,
         'center': center_dict['center'], 'config': config,
         'optimizer': optimizer}
I think that solved the problem. I will keep you updated.
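
For a full resume (not only the weights), I guess I would also restore the optimizer state, best_prec1, the feature center and the starting epoch from the same dict, placed right before the training loop in train(). Here is a rough sketch of what I have in mind, based only on the keys that save_checkpoint stores (start_epoch is my own variable, and the 'module.' prefix added by DataParallel may need handling if the checkpoint was saved with multi-GPU but loaded without it):

import os
import torch

ckpt_file = os.path.join(config.checkpoint_path, 'checkpoint.pth.tar')
checkpoint = torch.load(ckpt_file, map_location='cpu')

net.load_state_dict(checkpoint['state_dict'])        # model weights
optimizer.load_state_dict(checkpoint['optimizer'])   # SGD momentum buffers etc.
best_prec1 = checkpoint['best_prec1']                # keeps is_best meaningful
center_dict['center'].copy_(checkpoint['center'])    # feature centers, copied in place
start_epoch = checkpoint['epoch']                    # continue counting epochs from here

# and then loop from start_epoch instead of 0:
# for e in range(start_epoch, config.epochs):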