torch::jit::load in C++ and torch.jit.load in Python give dissimilar results for object detection models

I have noticed some discrepancy in object detection models when loaded in C++ vs Python. I train the models in Python, then call torch.jit.script(model) and save the scripted model.

This is the script that I use to test the scripted model in Python:

import cv2
import torch
import torchvision

model = torch.jit.load("scripted_model_fasterrcnn.pt", map_location='cpu')
image = cv2.imread("./sample_image.png")
image = cv2.resize(image, dsize=(320, 320))
model.eval()
transform = torchvision.transforms.Compose([
     torchvision.transforms.ToTensor()
])
image = transform(image)
print(model([image]))

The output from the model is as follows:

({}, [{'boxes': tensor([[111.5199,  48.1314, 165.9489,  92.9383],
        [ 24.5781,  26.3587, 291.7290, 269.1508],
        [200.7294,  65.7395, 210.2078,  91.2619],
        [ 96.1484,  96.5067, 145.2950, 139.3602],
        [ 96.5121,  99.2410, 218.9833, 142.0360],
        [152.3154, 215.1508, 172.9473, 256.4173],
        [ 74.7000,  61.1955,  97.1016,  90.3335],
        [ 47.2993,  77.8948,  79.3574, 106.0734],
        [ 42.6876, 118.5618,  62.9411, 158.2090],
        [154.7034, 117.0237, 219.6645, 138.5528],
        [225.9446, 173.9281, 247.5028, 219.6541],
        [224.7263,  95.1853, 257.9039, 128.9357],
        [176.1919, 202.1530, 207.5209, 238.8022],
        [179.5007, 199.3443, 204.4278, 235.2946],
        [235.8262, 144.1849, 263.2667, 180.3906],
        [125.7185,  54.9373, 167.9530,  91.9837],
        [ 43.0586, 116.2135,  63.8451, 157.1656],
        [ 65.7174, 159.9202,  89.8124, 202.1086],
        [122.1806, 110.5731, 210.8495, 139.0777],
        [ 63.2575, 159.2407,  92.9831, 202.0807],
        [235.7154, 146.1310, 262.6613, 179.0983],
        [222.2111,  93.1336, 258.3600, 128.4226],
        [111.7222,  46.9001, 121.2929,  87.5277],
        [ 96.2667, 204.5216, 130.4216, 233.2910],
        [158.1481,  92.5807, 254.6094, 132.3697],
        [155.5057, 203.3590, 209.4122, 259.7876],
        [ 74.0674,  62.7407,  96.0868,  91.1611],
        [ 92.2272,  48.6413, 171.6691,  93.8465],
        [176.9133, 201.2077, 209.6607, 236.1525],
        [218.8239,  93.3167, 254.8027, 126.9204],
        [ 96.1186, 205.9873, 129.3002, 233.0285],
        [ 45.7458,  77.8189,  79.6607, 106.9923],
        [ 94.7852, 205.3053, 126.1903, 229.0898],
        [225.9858, 184.0837, 246.4686, 220.3966],
        [223.9675, 177.2893, 249.1340, 217.3411],
        [ 62.5327, 159.2016,  95.3400, 203.9700],
        [119.9844,  57.4481, 169.6862,  92.2951],
        [224.6748, 149.6771, 261.5659, 212.7879],
        [ 99.8080, 204.5827, 131.1629, 232.3792],
        [177.7803, 199.8947, 206.4939, 241.7094],
        [ 64.6609, 160.5251,  94.8401, 203.9768],
        [230.0610, 148.3878, 263.5519, 212.8003],
        [227.0596, 152.7305, 254.4798, 218.7213],
        [152.1518, 219.0876, 171.5103, 257.3699],
        [ 45.1598,  79.2185,  77.0206, 107.5901],
        [149.8237, 215.0952, 170.7960, 257.7943],
        [217.8964, 150.4392, 263.4274, 210.8090],
        [139.8701, 114.6050, 214.3641, 135.9767],
        [100.0713,  94.3219, 202.3020, 140.7887],
        [ 64.4861, 161.6863,  92.4369, 204.7435],
        [ 80.9625,  67.3800, 217.4904, 138.9704],
        [223.4836, 177.3780, 246.8013, 219.4257],
        [115.7386,  49.3352, 167.7150,  87.9365],
        [  5.5940, 106.9003, 237.6417, 281.9904],
        [ 94.2908, 204.3514, 128.3428, 232.0317],
        [ 64.0087, 158.4752,  91.7718, 199.2004],
        [ 66.0727, 159.8815,  94.2593, 205.3629],
        [ 64.6510, 156.7599,  92.1091, 205.8702],
        [222.4584, 144.8448, 261.8528, 179.5527],
        [225.1163, 178.2792, 246.1298, 220.4667],
        [ 64.9479, 159.1237,  94.0426, 206.7240],
        [235.6259, 147.2666, 260.5404, 178.6416],
        [164.5046, 196.8825, 207.9899, 239.3440],
        [ 42.9777, 113.7448,  63.2752, 160.4895],
        [ 47.3674,  81.6245,  79.3286, 108.0966],
        [ 79.8154,  93.9216, 165.5143, 136.4308],
        [120.5117, 102.1464, 244.4896, 141.8823],
        [229.2781, 150.4267, 264.9536, 209.5770],
        [146.6634, 115.0739, 215.5014, 139.4496],
        [235.8229, 145.7525, 260.6239, 178.3852],
        [176.9435, 198.8155, 207.2011, 237.8320],
        [111.4658,  48.5251, 165.6438,  89.0411],
        [221.3267,  94.7368, 257.3989, 124.7556],
        [ 44.4494, 116.1530,  62.2219, 160.8552],
        [ 77.9704,  60.0059,  96.5681,  92.4873],
        [226.7721, 150.4558, 262.2142, 215.6963],
        [173.6458, 198.4269, 205.6388, 239.3172],
        [222.6702, 155.9789, 256.7563, 217.2549],
        [145.0301, 115.5764, 210.5471, 137.2198],
        [171.8129, 197.5060, 206.3469, 241.4117],
        [ 74.3697,  60.7232,  96.1948,  91.3447],
        [228.7881, 168.0770, 270.0812, 219.7723],
        [ 45.3199,  78.4097,  80.9535, 104.3143],
        [ 44.8993,  78.3313,  54.2288, 100.2458],
        [208.5768,  95.6042, 263.6136, 125.4095],
        [121.7834,  56.3466, 176.2651,  90.1986],
        [148.9704, 116.7540, 220.2704, 138.1128],
        [ 94.8379, 205.5583, 130.8381, 231.7617]], grad_fn=<StackBackward>), 'labels': tensor([12, 13,  1, 15, 14,  6, 11, 10,  9, 16,  6,  2,  3, 15,  3,  2,  7, 15,
        14,  7,  5,  3,  1,  3, 14, 14,  9, 14,  5, 12,  7, 12, 15,  4,  5,  9,
         3,  5,  2,  2,  8, 16,  6,  5, 11,  8, 14, 12, 15,  3, 14,  8, 15, 13,
         4,  4,  5, 12, 14,  3,  2,  7, 14,  4,  9, 14, 16,  4,  3,  2,  6,  4,
         5, 15,  1,  7,  7,  3, 15, 12,  4, 14,  8,  1, 14, 16,  5,  5]), 'scores': tensor([0.9840, 0.9696, 0.9655, 0.9570, 0.9495, 0.9191, 0.9101, 0.9074, 0.8563,
        0.8375, 0.7017, 0.6758, 0.6390, 0.5955, 0.5732, 0.5603, 0.5590, 0.5259,
        0.5190, 0.5073, 0.5038, 0.4409, 0.4315, 0.3957, 0.3925, 0.3746, 0.3384,
        0.3298, 0.3157, 0.2993, 0.2872, 0.2833, 0.2648, 0.2647, 0.2576, 0.2301,
        0.2174, 0.2119, 0.2012, 0.1873, 0.1838, 0.1754, 0.1706, 0.1653, 0.1609,
        0.1571, 0.1569, 0.1370, 0.1356, 0.1304, 0.1230, 0.1214, 0.1163, 0.1112,
        0.1106, 0.1051, 0.1034, 0.0989, 0.0987, 0.0952, 0.0917, 0.0905, 0.0886,
        0.0879, 0.0879, 0.0869, 0.0839, 0.0803, 0.0795, 0.0793, 0.0766, 0.0755,
        0.0747, 0.0716, 0.0647, 0.0641, 0.0620, 0.0604, 0.0585, 0.0582, 0.0558,
        0.0557, 0.0548, 0.0525, 0.0518, 0.0517, 0.0517, 0.0513],
       grad_fn=<IndexBackward>)}])

For the Python-loaded model, the scores are high and the labels correspond to the objects found in the image.

However, when I load the model in C++ using torch::jit::load with the code below and run it on the same image, I get a different result.


#include <torch/script.h> // One-stop header.
#include <torchvision/vision.h>
#include <iostream>
#include <memory>
#include <opencv2/core.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>

using namespace cv;

int main(int argc, const char* argv[]) {
  if (argc != 3) {
    std::cerr << "usage: example-app <path-to-exported-script-module> <path to image>\n";
    return -1;
  }

  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
    module.eval();
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  Mat img = imread(argv[2], IMREAD_COLOR);
  resize(img, img, Size(320, 320));

  namedWindow( "Display window", WINDOW_AUTOSIZE);
  imshow("Display window", img);
  waitKey(0);

  at::Tensor tensor_image = torch::from_blob(img.data, { 3, img.rows, img.cols}, at::kByte);

  at::Tensor tensor = tensor_image.toType(c10::kFloat).div(255);

  std::vector<torch::Tensor> inputs;
  // at::Tensor tensor = torch::ones({3, 320, 320})
  inputs.push_back(tensor);

  auto outputTuple = module.forward({at::TensorList(inputs)}).toTuple();

  std::cout << "ok\n";

  std::cout << outputTuple->elements()[1].toList().get(0).toGenericDict();
}

The printed output from the C++ module is as follows:

{boxes:  205.5799    6.2811  227.4993   66.4494
  88.8086   21.2261  238.2052  309.1224
 197.4220    1.8514  311.3408   81.0451
   0.0000   64.1860  170.1354  311.5446
  83.0187  147.5506  312.8422  307.5482
  71.2753    0.0000   85.0268   29.9290
  96.1392  267.6833  119.3408  309.8810
   0.0000    3.9964  207.4197  198.1970
 136.5122   38.4333  172.6665   67.5004
 101.1439  126.1354  147.7678  155.5787
 287.5102  292.9289  320.0000  320.0000
 293.8416   32.3060  320.0000   52.3225
 185.8121  249.4510  232.8698  296.0813
 152.3779   58.6546  320.0000  313.7756
  87.7302  123.1198  147.1136  159.4511
  76.8865    0.0731  151.0789   43.7236
 202.9545  160.2306  239.7407  222.1995
 170.8793   82.5652  186.4131  145.5943
  10.8631  107.5374  258.6507  277.9001
 289.7677  205.6405  306.5864  267.3916
 283.7892  296.6545  320.0000  319.8353
   0.0000   47.4983   12.8185   74.0877
 246.7242  264.9669  275.9197  299.8090
 289.8707    0.0000  318.3735   24.2169
 289.2394  202.5244  306.8430  263.7386
  57.0391  239.6176   77.3420  293.2146
 239.3128  265.3953  283.2209  302.6130
 285.8833  199.9184  305.3768  263.6747
 170.5636   81.9955  189.4623  148.0602
 283.2243  202.2214  307.5496  267.4892
 141.4374  239.7231  163.1488  299.6705
 244.0393  269.3799  282.4297  302.4154
 135.7289   35.5234  175.7074   70.3571
  54.1985  212.2170   69.0485  254.8595
 231.2240  138.4161  319.2614  173.3159
 117.9162    0.2215  141.9003   34.5996
  69.1365   62.8962   81.2370  133.4856
 200.4000  166.0829  238.2242  185.7499
   0.0000   48.2016   11.0690   78.6971
 286.1261    0.0000  300.9700   22.9680
   2.0455    2.8522   83.0857   34.8119
  48.2377  184.3904   80.8564  203.1248
 289.8303    0.0000  319.2095   47.0337
 204.0433    0.1787  310.6933   65.3949
 287.4706    0.2284  301.7375   23.4463
  76.7976   53.7115  113.4441   69.8022
 293.9149   31.3739  319.2691   52.6524
 289.1787    0.9553  301.8062   21.5435
 145.2040  247.9566  161.1664  296.7121
  54.4559  213.4805   68.5000  253.0465
  81.6588    0.0000  113.1224   24.5262
 289.3730    0.1651  318.8288   21.1703
   1.0660  292.7904   32.1888  312.4598
 288.5013    0.0000  302.3147   25.4572
  76.1369    0.0000  109.1329   28.4577
  83.8482    1.1054  121.8937   46.5508
  80.9367  203.9415  100.4512  250.8624
 295.9372   30.4247  316.6146   48.4681
  51.5779  215.7196   79.1016  298.0002
 242.8738  130.8215  317.9763  167.6260
 286.8958    0.0000  320.0000   47.9069
  51.3604  216.0111   76.5821  297.2584
  81.2741  188.3169  118.2954  205.3452
 287.9591    0.0000  305.4834   23.5711
 246.7157  193.4644  280.3631  271.3096
 250.7676  195.0680  270.1045  240.3511
  54.5070  212.4253   68.8881  254.8246
  56.6173  210.5370   74.5919  262.3492
  82.4888    0.0000  112.4623   24.7118
 282.8600  295.9117  318.2802  319.6000
 172.0242   80.4292  193.0759  148.1853
 156.4728   13.4416  172.0676   38.0945
   0.0000   46.8211   12.2823   77.8516
 245.4109  188.0984  280.5365  264.8040
 247.2911  185.7036  278.7738  269.1025
 135.4207   35.6916  166.4271   68.0820
  52.2134  237.3519   75.1517  296.9148
  82.2377    0.0000  113.0828   24.7477
  90.9985  158.6708  118.6184  188.4403
  48.7953  185.0795   83.5775  205.4890
 132.8539   35.5799  173.2599   70.8452
  86.4917   29.3484  129.3071   44.4802
 145.0171  260.0258  159.2331  301.5958
 227.4105   92.2210  320.0000  116.7067
  75.1366   54.0580  110.6670   68.8280
 287.1800    0.0634  300.7803   22.5780
   2.9772  103.8928   76.8366  139.0727
 172.0433   84.2483  193.3882  151.5868
  48.2970  184.3229   81.3700  204.5005
 133.2772   38.6679  176.9788   71.6648
   1.0222    3.8745   72.9432   33.9982
 138.8353   35.1738  179.3036   72.4320
  95.6790  126.5814  147.4179  155.7961
 156.9208   14.3203  172.5015   37.7791
   0.0000   49.0658   12.4783   74.6269
 144.0489  242.5912  167.5690  302.5341
 173.9803   78.1983  190.5706  144.0554
 207.2380    2.6587  312.0267   61.4207
  68.3068    0.0000   84.3640   29.5561
 287.6278  294.9377  318.8213  320.0000
[ CPUFloatType{100,4} ], labels:  15
 13
 14
 13
 13
  1
 15
 13
  4
 16
  3
  3
 15
 13
 14
 14
 15
  7
 13
  5
  5
  8
 15
 12
  6
  7
  3
  7
  1
  1
  1
  2
  3
  9
 14
 15
  1
 15
  4
  1
 14
  2
 15
 15
  3
  2
  2
  6
  6
  7
  3
  4
  3
  2
 12
 15
 15
 15
  4
 15
  3
  9
  2
  5
  7
 15
  1
 15
 16
  4
  6
  3
  9
  9
 15
 15
  3
  2
 15
  4
  7
  3
  1
 14
 15
  7
 15
  5
 12
  5
 15
  2
  3
  7
  5
  5
 15
 16
  7
  2
[ CPULongType{100} ], scores:  0.6805
 0.5947
 0.5202
 0.5133
 0.4574
 0.4370
 0.4332
 0.4081
 0.3999
 0.3866
 0.3424
 0.3407
 0.3351
 0.3343
 0.3180
 0.3104
 0.2920
 0.2843
 0.2832
 0.2828
 0.2801
 0.2799
 0.2678
 0.2649
 0.2510
 0.2486
 0.2375
 0.2351
 0.2344
 0.2311
 0.2269
 0.2251
 0.2208
 0.2132
 0.2065
 0.2016
 0.2010
 0.1989
 0.1978
 0.1953
 0.1933
 0.1846
 0.1759
 0.1718
 0.1714
 0.1704
 0.1684
 0.1655
 0.1654
 0.1570
 0.1529
 0.1526
 0.1524
 0.1465
 0.1442
 0.1434
 0.1413
 0.1411
 0.1393
 0.1370
 0.1323
 0.1304
 0.1302
 0.1274
 0.1272
 0.1268
 0.1260
 0.1240
 0.1234
 0.1221
 0.1201
 0.1198
 0.1193
 0.1191
 0.1174
 0.1170
 0.1161
 0.1157
 0.1152
 0.1126
 0.1124
 0.1122
 0.1120
 0.1110
 0.1091
 0.1081
 0.1062
 0.1055
 0.1028
 0.1016
 0.1004
 0.0995
 0.0989
 0.0982
 0.0960
 0.0926
 0.0921
 0.0916
 0.0898
 0.0897

The outputs of the two models are quite different, with the C++ one performing poorly. The Faster R-CNN model I am using has a ResNet-50 backbone with FPN. I also experimented with a RetinaNet model and with MobileNetV2 and MobileNetV3 models with quantised backbones. None gave results similar to the TorchScript model loaded in Python.

I am not sure what is going on, as there were no errors when saving and loading the TorchScript models. Any advice would be greatly appreciated.

EDIT:
Feeding torch.ones as input to both models also returns different outputs.
PyTorch version: 1.9
Torchvision version: 0.10
Libtorch: Preview (Nightly), macOS

I did another test of converting the pretrained fasterrcnn_mobilenet_v3_large_320_fpn model into a TorchScript module. The reproducible code is below.

import cv2
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import fasterrcnn_mobilenet_v3_large_320_fpn

model = fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True)
model.eval()
image = cv2.imread("./test1.png")
image = cv2.resize(image, dsize=(320, 320))
transform = torchvision.transforms.Compose([
     torchvision.transforms.ToTensor()
])
image = transform(image)
script_model = torch.jit.script(model)
torch.jit.save(script_model, "scripted_model_mobilenetv3_large_320_fpn.pt")
print(script_model([image]))

Python Output:

({}, [{'boxes': tensor([[ 57.2329,  61.6968, 240.5976, 270.2655],
        [ 49.8475, 110.0096, 137.2613, 300.0203],
        [191.1161,  39.0309, 286.9027,  94.8667],
        [188.2992,  36.5230, 288.7202,  97.4937],
        [ 53.7912,  65.7180, 106.5928, 123.5672],
        [ 54.1852, 112.7328, 134.0000, 302.0532],
        [ 30.1234,  80.8505,  97.9023, 221.6472],
        [ 27.9103,  47.1386,  42.7512,  66.0757],
        [ 50.6809,  45.9808, 110.9343, 113.8588],
        [166.8281,  51.4230, 193.9323,  72.8806],
        [  1.9139,   1.3991,  28.6922, 141.1715],
        [ 34.4260,  77.5499, 105.4713, 233.3638],
        [ 28.4753,  50.9322,  48.0249,  68.2570],
        [158.8731,  53.8971, 182.2708,  70.5178]], grad_fn=<StackBackward>), 'labels': tensor([ 2, 18,  3,  8,  2, 17, 62,  3, 64,  3, 72,  2,  4,  3]), 'scores': tensor([0.9982, 0.7730, 0.6875, 0.3241, 0.2950, 0.1868, 0.1649, 0.0912, 0.0715,
        0.0705, 0.0608, 0.0604, 0.0557, 0.0547], grad_fn=<IndexBackward>)}])

Using the C++ code above, the C++ output is:

{boxes:    0.0000    0.0000  320.0000  320.0000
  36.9277    2.1780  264.9876  307.3490
  11.3890    0.0000  319.6223  316.5984
[ CPUFloatType{3,4} ], labels:  72
  1
 82
[ CPULongType{3} ], scores:  0.1217
 0.0593
 0.0533
[ CPUFloatType{3} ]}

test1 image from the D2Go repository

I believe this line of code:

at::Tensor tensor_image = torch::from_blob(img.data, { 3, img.rows, img.cols}, at::kByte);

would interpret the data with interleaved channels: OpenCV loads the image in a channels-last (HWC) memory layout, while you are viewing the raw buffer as a channels-first (CHW) tensor_image, so you might need to create the tensor as channels-last and permute it afterwards.
While this might be one issue, I don't know why constant values (torch.ones) would also give different outputs, so there might be another error in the code.
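The indexing mix-up can be sketched in plain Python with a hypothetical 2×2, 3-channel buffer (toy values, not real image data). On the libtorch side the corresponding fix would be to create the tensor with the shape OpenCV actually uses and then permute, e.g. `torch::from_blob(img.data, {img.rows, img.cols, 3}, at::kByte).permute({2, 0, 1}).contiguous()`:

```python
# Toy HWC buffer, as OpenCV stores it: H=2, W=2, C=3 (hypothetical values).
H, W, C = 2, 2, 3
# Pixel p, channel c stores p * 10 + c, flattened in HWC order.
flat = [p * 10 + c for p in range(H * W) for c in range(C)]

# Wrong: viewing the flat HWC buffer directly as {C, H, W} just slices it,
# so "channel 0" mixes channels from several different pixels.
wrong_chw = [flat[i * H * W:(i + 1) * H * W] for i in range(C)]

# Right: index the buffer as HWC, then gather per channel (i.e. permute).
right_chw = [[flat[(h * W + w) * C + c] for h in range(H) for w in range(W)]
             for c in range(C)]

print(wrong_chw[0])   # [0, 1, 2, 10]   -- scrambled across pixels and channels
print(right_chw[0])   # [0, 10, 20, 30] -- channel 0 of every pixel
```

This also explains why the model degrades rather than erroring out: the shapes match, but every "pixel" the network sees is a mix of values from neighbouring pixels and channels.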

Thank you for your reply. It seems that the channel dims were indeed the problem.