Why can't libtorch access 'running_mean' and 'running_var' of BatchNorm2d in a .pt file?

I’ve found that the libtorch result is very different from the PyTorch result.
I save my model as a .pt file using PyTorch and load it with libtorch, and all parameters are copied successfully, which I double-checked by loading the file in PyTorch again.

After inspecting module.parameters() in libtorch, I’ve found that none of the BatchNorm2d modules contain running_mean and running_var. Also, after the first conv layer, the parameters of the following modules no longer match the original values.

I’m looking for someone who has run into the same problem and investigated it. I think it is one of the main reasons why the results differ between libtorch and PyTorch.
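
(Side note: running_mean and running_var are registered as buffers, not parameters, so they never show up in module.parameters() in either API. A quick Python-side check of that split:)

# buffers_check.py (a minimal sketch showing the parameter/buffer split)
import torch.nn as nn

bn = nn.BatchNorm2d(3)

# Trainable tensors: only the affine weight and bias
print([name for name, _ in bn.named_parameters()])
# > ['weight', 'bias']

# Non-trainable state: the running stats live here instead
print([name for name, _ in bn.named_buffers()])
# > ['running_mean', 'running_var', 'num_batches_tracked']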

I cannot reproduce the error; this dummy code using nn.BatchNorm2d successfully exports and loads the running stats:

# script.py
import torch
import torch.nn as nn
import torchvision.models as models

# Setup
model = nn.BatchNorm2d(3)
# model = models.resnet18()  # a full model works the same way

# Perform some dummy updates for the running stats
for _ in range(10):
    out = model(torch.randn(10, 3, 224, 224) * 5 + 10.)

# Script
model = torch.jit.script(model)
model.eval()

# Save
data = torch.ones(1, 3, 224, 224)
output = model(data)
torch.save(output, 'output.pth')

model.save('resnet_scripted.pth')

print(model.running_mean)
> tensor([6.5143, 6.5129, 6.5125])
print(model.running_var)
> tensor([16.6550, 16.6079, 16.6480])
# main.cpp
#include <torch/script.h> // One-stop header.

#include <fstream>
#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }


  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
    module.eval();
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({1, 3, 224, 224}));
    at::Tensor output = module.forward(inputs).toTensor();
    auto bytes = torch::jit::pickle_save(output);
    std::ofstream fout("libtorch_out.zip", std::ios::out | std::ios::binary);
    fout.write(bytes.data(), bytes.size());
    fout.close();

    for (const auto& b : module.buffers()) {
        std::cout << b << std::endl;
    }
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::cout << "ok\n";
}

>  6.5143
 6.5129
 6.5125
[ CPUFloatType{3} ]
 16.6550
 16.6079
 16.6480
[ CPUFloatType{3} ]
10
[ CPULongType{} ]
ok
# compare.py
import torch

libtorch_out = torch.load('libtorch_out.zip')
python_out = torch.load('../output.pth')

print((libtorch_out - python_out).abs().max())
> tensor(0., grad_fn=<MaxBackward1>)
print(torch.allclose(libtorch_out, python_out))
> True

As you can see, the Python and C++ APIs both output the same running stats and also yield the same output.

Could you post a reproducible code snippet, which shows that the running stats are not loaded?


The first snippet is the part of my Python code which loads the .pth file and manually overwrites all parameters of a ResNet-50.
The second one is part of my libtorch code that loads the .pt file. According to the std::cout output, I cannot see all four tensors of each BatchNorm module: only bn.weight and bn.bias appear (and they match), and afterwards the parameters of the following modules have changed.

import torch
import torchvision

if __name__ == '__main__':
    print("%" * 60)
    example = torch.ones(1, 3, 1500, 1500)
    example.requires_grad = False

    model_res50 = torchvision.models.resnet50(pretrained=True)
    model_res50_FE = torch.nn.Sequential(*list(model_res50.children())[:-3])
    model_res50_FE.eval()

    model_os2d = torch.load("os2d_v2-train.pth")
    state = model_os2d['net']

    # Transfer the OS2D weights to model_res50_FE
    # ==== layer 0 =========
    model_res50_FE[0].weight = torch.nn.Parameter(state['net_feature_maps.conv1.weight'])
    model_res50_FE[1].weight = torch.nn.Parameter(state['net_feature_maps.bn1.weight'])
    model_res50_FE[1].bias = torch.nn.Parameter(state['net_feature_maps.bn1.bias'])
    model_res50_FE[1].running_mean = state['net_feature_maps.bn1.running_mean']
    model_res50_FE[1].running_var = state['net_feature_maps.bn1.running_var']
    # model_res50_FE[1].num_batches_tracked = state['net_feature_maps.bn1.num_batches_tracked']
    # ======================

    # ==== layer 1 =========
    block = model_res50_FE[4][0]
    block.conv1.weight = torch.nn.Parameter(state['net_feature_maps.layer1.0.conv1.weight'])
    block.bn1.weight = torch.nn.Parameter(state['net_feature_maps.layer1.0.bn1.weight'])
    block.bn1.bias = torch.nn.Parameter(state['net_feature_maps.layer1.0.bn1.bias'])
    block.bn1.running_mean = state['net_feature_maps.layer1.0.bn1.running_mean']
    block.bn1.running_var = state['net_feature_maps.layer1.0.bn1.running_var']
    # block.bn1.num_batches_tracked = state['net_feature_maps.layer1.0.bn1.num_batches_tracked']

    block.conv2.weight = torch.nn.Parameter(state['net_feature_maps.layer1.0.conv2.weight'])
    block.bn2.weight = torch.nn.Parameter(state['net_feature_maps.layer1.0.bn2.weight'])
    block.bn2.bias = torch.nn.Parameter(state['net_feature_maps.layer1.0.bn2.bias'])
    block.bn2.running_mean = state['net_feature_maps.layer1.0.bn2.running_mean']
    block.bn2.running_var = state['net_feature_maps.layer1.0.bn2.running_var']
    # block.bn2.num_batches_tracked = state['net_feature_maps.layer1.0.bn2.num_batches_tracked']

    block.conv3.weight = torch.nn.Parameter(state['net_feature_maps.layer1.0.conv3.weight'])
    block.bn3.weight = torch.nn.Parameter(state['net_feature_maps.layer1.0.bn3.weight'])
    block.bn3.bias = torch.nn.Parameter(state['net_feature_maps.layer1.0.bn3.bias'])
    block.bn3.running_mean = state['net_feature_maps.layer1.0.bn3.running_mean']
    block.bn3.running_var = state['net_feature_maps.layer1.0.bn3.running_var']
    # block.bn3.num_batches_tracked = state['net_feature_maps.layer1.0.bn3.num_batches_tracked']

    ...

    # torch.jit.script does not take example inputs (that would be torch.jit.trace)
    model_res50_FE_scripted = torch.jit.script(model_res50_FE)
    print(model_res50_FE.forward(example))
    print(model_res50_FE.forward(example).size())

    model_res50_FE_scripted.save("ResNet50_os2d_FE_1500_witlEval_withScript_Pretrained.pt")  # TODO:

torch::jit::script::Module model_body, model_body_que, model_head;
model_body = torch::jit::load("D:/code/python/one_shot/ResNet50_os2d_FE_1500_witlEval_withScript_Pretrained.pt");
model_body.to(device);
model_body.eval();

for (const auto& abc : model_body.attributes()) {
    // attributes() also yields non-tensor IValues (e.g. the training flag),
    // so only print the tensor-valued entries
    if (abc.isTensor()) {
        std::cout << abc.toTensor() << std::endl;
        system("pause");
    }
}
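
(Side note: manually assigning each tensor as above is error-prone. A rough alternative sketch, assuming the full model_res50 from the snippet above and the net_feature_maps. prefix visible in the checkpoint keys, would remap the keys and let load_state_dict copy parameters and buffers in one call:)

# Hypothetical alternative to the manual copies above
import torch

ckpt = torch.load("os2d_v2-train.pth")
prefix = "net_feature_maps."
state = {k[len(prefix):]: v for k, v in ckpt['net'].items() if k.startswith(prefix)}

# strict=False because the feature extractor drops the trailing blocks
missing, unexpected = model_res50.load_state_dict(state, strict=False)
print(missing, unexpected)  # sanity-check which keys did not match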

I’ve tested printing the buffers in my C++ code, and it outputs the same values.
I’ve also found that the other parameters are the same; I was confused by the difference between the column-first and row-first printing conventions of Python and C++. I’m sorry about that, I misunderstood it.
But now I’m curious whether the buffers (such as running_mean and running_var) are actually applied when the module runs its forward pass.
Thank you very much for the helpful reply!


Could you explain a bit what “forward process” means?
The parameters, buffers (and potentially other tensors) will be applied as they were defined in the forward method. I’m not sure I understand the question completely.
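
To make that concrete, here is a minimal sketch (not from the original code) showing that in eval() mode batchnorm normalizes with the stored running-stat buffers rather than the batch statistics:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3).eval()
x = torch.randn(4, 3, 8, 8)

out = bn(x)

# Reproduce the eval-mode forward by hand, using only the buffers
mean = bn.running_mean.view(1, 3, 1, 1)
var = bn.running_var.view(1, 3, 1, 1)
w = bn.weight.view(1, 3, 1, 1)
b = bn.bias.view(1, 3, 1, 1)
manual = (x - mean) / (var + bn.eps).sqrt() * w + b

print(torch.allclose(out, manual, atol=1e-6))
# > True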


oh, okay.
After the model forward, I get an output which is the feature map of the image, since I cut off the average-pooling and fc layers of the ResNet-50. When I sum all elements of the feature-map tensor, I get 90000 from Python and 180000 from C++. I don’t think it is caused by float precision, since the gap is too large. Since all parameters are the same, I can only suspect that the buffers are not applied correctly when the model runs its forward pass.
I’ve checked that the resize results of C++ and Python are the same. The reason I checked is that someone said using PIL vs. OpenCV may cause different results.

Again, thank you so much for your quick and helpful replies every time!

Make sure to use the same input (e.g. torch.ones) in the same shape and to call model.eval() in both the Python and the C++ application.
Also, check the output shapes and see if there is any difference.
If you are still getting different outputs, try to scale down the problem, e.g. by using a standard ResNet without any manipulations.
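
A minimal Python-side baseline for that scaled-down test could look like this (file names are placeholders); the C++ side would then load the same .pt file, feed the same torch::ones input, and compare the printed sum:

import torch
import torchvision

# Stock ResNet-50, no manual weight surgery, deterministic constant input
model = torchvision.models.resnet50(pretrained=True)
model.eval()

data = torch.ones(1, 3, 224, 224)
with torch.no_grad():
    out = model(data)

print(out.shape, out.sum())  # compare these values against the C++ output

torch.jit.script(model).save("resnet50_baseline.pt")
torch.save(out, "python_baseline_out.pth")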


Oh, do you mean that I should feed the same input, like torch.ones, and then call model.eval(), even though I’ve already loaded the .pt file in C++? Without that, is it not possible to initialize the model in C++?

After testing your advice, it outputs almost the same result with torch.ones and a standard ResNet…

Using a constant input would remove the possibility of different preprocessing or loading logic between your code bases. It’s not needed to initialize the model.
Did you call model.eval() on your own model as well? This would be necessary to compare the outputs, as e.g. dropout layers might otherwise be active in the models.

Could you use the same workflow (constant inputs and model.eval()) with your own model and check the outputs again?


I still find that the output is very different: the element sums of the output feature maps differ (180000 for C++ and 90000 for Python). But due to the project deadline, I am prioritizing translating the other modules of the entire model first.
I will use a constant input as you advised and check for dropout layers! And yes, I call model.eval() when I use my own model as well.
It was a great help and I really appreciate it! After finishing the rest of the translation task, I will go through all of your advice again!