Zero accuracy after loading a saved model

I’m training torchvision’s resnet18 on a GPU on the Omniglot dataset. After training, I save the model using the following: ..., 'models/%s/model.pth' % model_name)

Then I try to load the model on CPU using:

model.load_state_dict(torch.load('model.pth', map_location=config.device))
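For reference, a minimal sketch of the usual state-dict round trip (save the parameters, then load them on a CPU-only machine). The small `nn.Linear` module and the temporary path are hypothetical stand-ins, not the network or path from the post:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-in for the real network; any nn.Module behaves the same way.
model = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "model.pth")

# Save only the parameters (the state dict), not the pickled module object.
torch.save(model.state_dict(), path)

# map_location remaps tensors saved on a GPU so they can be loaded on CPU.
state = torch.load(path, map_location=torch.device("cpu"))

restored = nn.Linear(4, 2)
restored.load_state_dict(state)  # strict=True is the default
restored.eval()

# The restored module should reproduce the original outputs.
x = torch.randn(3, 4)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print(same)
```

If the outputs match here but accuracy still collapses in the real script, the saved file itself is probably fine and something else in the evaluation path differs from training.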

When I try to validate the model, I get an accuracy of 0.0 on the test dataset, even though the test accuracy during training was around 90%. The same happens with a different net trained on the MNIST dataset: during training, the test accuracy was around 98%, but once the net is saved and loaded again the accuracy collapses in the same way. With 10 classes, a model outputting random numbers would be correct in approximately 10% of cases, and the result is worse than even that.
This is the code for the validation:

import torch
from collections import OrderedDict
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

# AverageMeter and accuracy are helper utilities defined elsewhere in the project.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def validate(val_loader, model, metric_fc):
    losses = AverageMeter()
    acc1s = AverageMeter()
    criterion = CrossEntropyLoss()
    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (_input, target) in tqdm(enumerate(val_loader), total=len(val_loader)):
            _input = _input.to(device)
            target = target.long().to(device)

            # backbone produces features, the metric head turns them into logits
            feature = model(_input)
            output = metric_fc(feature, target)
            loss = criterion(output, target)

            acc1, = accuracy(output, target, topk=(1,))

            losses.update(loss.item(), _input.size(0))
            acc1s.update(acc1.item(), _input.size(0))
    val_log = OrderedDict([
        ('loss', losses.avg),
        ('acc1', acc1s.avg),
    ])

    return val_log

metric_fc = CosFace(num_features=512, num_classes=10).to(device)

model = MNISTNet(num_features=512).to(device)
model.load_state_dict(torch.load('model.pth'), strict=False)

log = validate(test_loader, model, metric_fc)


Does it work with strict=True?
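As a rough illustration of what `strict` changes: with `strict=False`, `load_state_dict` skips mismatched keys and reports them in its return value, while `strict=True` raises an error. The `Backbone` class and the doctored checkpoint below are hypothetical, made up only to trigger both cases:

```python
import torch
import torch.nn as nn

class Backbone(nn.Module):
    """Hypothetical stand-in for the real network."""
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(self.features(x))

model = Backbone()

# Build a checkpoint with one key missing and one key the model does not know.
state = dict(model.state_dict())
del state["fc.weight"]
state["head.weight"] = torch.zeros(2, 4)

# strict=False loads what matches and reports the rest instead of failing.
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)
print(result.unexpected_keys)

# strict=True refuses to load a mismatched checkpoint.
try:
    model.load_state_dict(state, strict=True)
    strict_failed = False
except RuntimeError:
    strict_failed = True
print(strict_failed)
```

Note that `strict` only checks the module it is called on. It cannot flag a separate module, such as `metric_fc`, for which `load_state_dict` was never called.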

I notice that a network called metric_fc is created, but no state is loaded for it. Is this network not trained?
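One way to avoid this, sketched under the assumption that the backbone and the head are two separate modules, is to store both state dicts in a single checkpoint and restore both before validating. The two `nn.Linear` modules and the dictionary keys are hypothetical stand-ins for `MNISTNet` and `CosFace`:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical stand-ins for the trained backbone and metric head.
backbone = nn.Linear(16, 8)
head = nn.Linear(8, 10, bias=False)

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pth")

# Save both sets of parameters together so neither can be forgotten.
torch.save({"model": backbone.state_dict(),
            "metric_fc": head.state_dict()}, path)

# Later (possibly on a CPU-only machine): rebuild both modules and load both.
checkpoint = torch.load(path, map_location="cpu")
new_backbone = nn.Linear(16, 8)
new_head = nn.Linear(8, 10, bias=False)
new_backbone.load_state_dict(checkpoint["model"])
new_head.load_state_dict(checkpoint["metric_fc"])

# The restored pair should reproduce the original logits.
x = torch.randn(5, 16)
with torch.no_grad():
    match = torch.allclose(head(backbone(x)), new_head(new_backbone(x)))
print(match)
```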


Could it be that not all parameter weights were loaded correctly, such as the last fully connected output layer?
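A possible explanation for accuracy below the 10% chance level, assuming a CosFace-style head that subtracts a margin from the target-class logit: the validation code passes `target` into `metric_fc`, so an untrained head lowers exactly the correct class's score. The `MarginHead` class below is a hypothetical re-implementation for illustration, not the `CosFace` class from the post, and the random features stand in for backbone outputs:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MarginHead(nn.Module):
    """Hypothetical CosFace-style head: s * (cos(theta) - m) on the target class."""
    def __init__(self, num_features, num_classes, s=30.0, m=0.35):
        super().__init__()
        self.s, self.m = s, m
        self.weight = nn.Parameter(torch.randn(num_classes, num_features))

    def forward(self, feature, target):
        # Cosine similarity between normalised features and class weights.
        cosine = F.linear(F.normalize(feature), F.normalize(self.weight))
        # Subtract the margin only from the target-class entry.
        one_hot = F.one_hot(target, cosine.size(1)).to(cosine.dtype)
        return self.s * (cosine - self.m * one_hot)

torch.manual_seed(0)
head = MarginHead(512, 10)              # freshly initialised, never loaded
feature = torch.randn(1000, 512)        # stand-in for backbone features
target = torch.randint(0, 10, (1000,))

with torch.no_grad():
    logits = head(feature, target)
acc = (logits.argmax(dim=1) == target).float().mean().item()
print(acc)  # far below the 10% expected from uniform random guessing
```

With random weights the cosines are all close to zero, so the margin makes the correct class the least likely prediction rather than an equally likely one.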

Outputs the same accuracy

Thank you, this was the solution. I wasn’t loading the head used to train the network. But my question is: if I want to export the network model for use on Android, for example, do I need to export the head model as well?

You have to if it’s Python code. There are tools to export a pre-computed model for C++ (I think). @tom may help you.


I’m not sure I understand the question completely, but I would probably try to put everything into a single model to export. This could include pre- and postprocessing. When I originally ported maskrcnn to Android, I even made a C++ extension to write the detected labels into the image and made that part of the JIT-traced model. That might be over the top for your application, though.
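A minimal sketch of the single-model idea, assuming the head reduces to a cosine similarity against learned class weights at inference time (no margin and no target needed). The `Exportable` wrapper, the toy backbone, and the random class weights are hypothetical placeholders for the trained parts:

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.functional as F

class Exportable(nn.Module):
    """Hypothetical wrapper bundling backbone and inference-time head."""
    def __init__(self, backbone, class_weight):
        super().__init__()
        self.backbone = backbone
        # Store normalised class weights as a buffer so they are saved with the model.
        self.register_buffer("class_weight",
                             F.normalize(class_weight.detach().clone()))

    def forward(self, x):
        feature = F.normalize(self.backbone(x))
        return F.linear(feature, self.class_weight)  # cosine score per class

# Toy backbone and random weights stand in for the trained network and head.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 512))
class_weight = torch.randn(10, 512)

model = Exportable(backbone, class_weight).eval()
example = torch.randn(1, 1, 28, 28)

# Trace the combined module and save it as one self-contained file.
traced = torch.jit.trace(model, example)
path = os.path.join(tempfile.mkdtemp(), "model_traced.pt")
traced.save(path)

# The saved file can be loaded without the Python class definitions.
loaded = torch.jit.load(path, map_location="cpu")
with torch.no_grad():
    same = torch.allclose(model(example), loaded(example), atol=1e-6)
print(same)
```

Bundling the head this way means the exported file carries everything needed for prediction, so there is no second set of weights to forget on the deployment side.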

Best regards

