Python && C++ result consistency

I am running a CNN in C++ that was trained in Python 3 and exported via TorchScript.
However, I cannot reproduce the results on the test data.

The class labels the classifier returns in the two languages are opposite for the identical data patch. I am not sure whether some memory is corrupted in the C++ implementation, but I was able to validate that at least the Python implementation gives correct results. How else should I debug this? Right now I am trying to compare the weight matrices of the conv kernels to validate that the model is consistent between both implementations. Note: the patch statistics are identical in both languages, so could the memory be misaligned in C++?
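For the kernel comparison, something along these lines might work on the C++ side (a minimal sketch, assuming a libtorch version where named_parameters() is available on the scripted module); the Python side can print matching sums from cnn.state_dict() for a quick diff.

    // Hedged sketch: print a checksum per parameter of the scripted module,
    // to be compared against the sums of the Python state_dict
    // (requires #include <torch/script.h>; cnn is the loaded torch::jit::script::Module).
    for (const auto& p : cnn.named_parameters(/*recurse=*/true)) {
        std::cout << p.name << " sum: " << p.value.sum().item<double>() << std::endl;
    }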

Relevant scripts and outputs:

Python export script:

    # rebuild the network and load the trained weights
    cnn = CNN1()
    cnn.load_state_dict(torch.load(args.cnn))

    # compile to TorchScript and save it for the C++ runtime
    torch_script_nn = torch.jit.script(cnn)

    script_file_name = args.cnn.replace('.torch', '.pytorch_script')
    torch_script_nn.save(script_file_name)

    print('{} CONVERTED TO {}'.format(args.cnn, script_file_name))
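On the C++ side, the saved file would then presumably be loaded with torch::jit::load, roughly like this minimal sketch (the path is illustrative):

    // Hedged sketch: load the exported TorchScript module for inference in C++
    // (requires #include <torch/script.h>).
    torch::jit::script::Module cnn = torch::jit::load("cnn.pytorch_script");  // illustrative path
    cnn.eval();  // disable training-time behaviour such as dropout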

Python test script

    # rebuild the network and load the trained weights
    cnn = CNN1()
    cnn.load_state_dict(torch.load(cnn_file))

    i = 100
    j = 100
    patch_size = 14
    img = np.load(f)['arr_0']
    # plt.imshow(img, cmap='gray');plt.show()

    # cut out a patch_size x patch_size window at (i, j)
    patch = img[i:i + patch_size, j:j + patch_size]
    print('Patch max : {}, min {}'.format(np.max(patch), np.min(patch)))

    plt.imshow(patch, cmap='gray');plt.show()

    # add a channel dimension and min-max normalize to [0, 1]
    patch = np.expand_dims(patch, axis=2)
    normalized_patch = (patch - np.min(patch)) / np.max(patch - np.min(patch))
    print('Normalized Patch max : {}, min {}'.format(np.max(normalized_patch), np.min(normalized_patch)))

    # batch of one in NHWC, then permute to NCHW for the model
    torch_stack = torch.stack([torch.Tensor(normalized_patch)])
    print(torch_stack.shape)
    res = cnn(torch_stack.permute(0, 3, 1, 2))

    # res holds the softmax scores for the two classes
    print('Class label {}'.format(res))

Prints

Patch max : 16.035221099853516, min 11.432785987854004
Normalized Patch max : 1.0, min 0.0
torch.Size([1, 14, 14, 1])
tensor([[9.9995e-01, 5.4011e-05]], grad_fn=<SoftmaxBackward>)

C++ test

std::cout << "TESTING CNN" << std::endl;
    std::size_t i = 100;
    std::size_t j = 100;

    const cv::Rect roi(i, j, patchSize, patchSize);
    cv::Mat amplitudePatch = amplitude(roi);
    double min;
    double max;
    cv::minMaxLoc(amplitudePatch, &min, &max);
    std::cout << "Patch Max : " << max << std::endl;
    std::cout << "Patch Min " << min << std::endl;

    cv::imshow("Patch", amplitudePatch);
    cv::Mat normalizedPatch = normalizeAmplitudePatch(amplitudePatch);
    cv::imshow("NPatch", normalizedPatch);
    cv::minMaxLoc(normalizedPatch, &min, &max);
    std::cout << "Normalized Patch Max : " << max << std::endl;
    std::cout << "Normalized Patch Min " << min << std::endl;

    cv::waitKey(0); 
    auto options = torch::TensorOptions().dtype(torch::kF32);

    std::cout << 1 << " " << amplitudePatch.size().height<< " " <<
        amplitudePatch.size().width<< " " << amplitudePatch.channels() << std::endl;

    // wrap the patch buffer as a {1, H, W, C} float tensor (from_blob makes no copy)
    torch::Tensor patchTensor = torch::from_blob(
      amplitudePatch.data,
      {1, amplitudePatch.size().height,
        amplitudePatch.size().width, amplitudePatch.channels()},
      options);

    std::vector<c10::IValue> inputs;
    inputs.push_back(patchTensor.permute({0, 3, 1, 2}));  // NHWC -> NCHW, as in Python
    torch::Tensor result = cnn.forward(inputs).toTensor();

    std::cout << result << std::endl;

Prints

Patch Max : 16.0352
Patch Min : 11.4328
Could not initialize OpenGL for RasterGLSurface, reverting to RasterSurface.
Could not initialize OpenGL for RasterGLSurface, reverting to RasterSurface.
Normalized Patch Max : 1
Normalized Patch Min 0
1 14 14 1
 0  1
[ Variable[CPUFloatType]{1,2} ]

A typical mistake people make is to forget that OpenCV stores images as BGR while TorchVision expects RGB. So you would need to use .flip on the channel dim.
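For example, for a three-channel NCHW tensor (a minimal sketch, sizes are illustrative):

    // Hedged sketch: reversing the channel dimension of an NCHW tensor
    // converts a BGR image tensor to RGB (and vice versa).
    torch::Tensor bgr = torch::rand({1, 3, 14, 14});  // illustrative BGR input
    torch::Tensor rgb = bgr.flip({1});                // flip dim 1 = channel dim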

Best regards

Thomas

Thanks for the suggestion, but I am using grayscale images with shape (x, y, 1), i.e. a single channel. Any other ideas?

Did you try copying the data? You currently have:

    const cv::Rect roi(i, j, patchSize, patchSize);
    cv::Mat amplitudePatch = amplitude(roi);

Add a clone():

    cv::Mat amplitudePatch = amplitude(roi).clone();

Cloning / copying the patch here indeed changes the output of the function to the correct value. I now have to pipe more examples through it to fully validate, but it looks promising… Thanks a lot! :)

Do you have an explanation for your solution? It seems the memory gets messed up at some point in my logic, but where exactly? The cv::imshow calls display the patch consistently with Python.

Thanks again!

Yep, this is how OpenCV manages data: when you use operator() to create an ROI, no matrix data is copied, for efficiency. But when you then access amplitudePatch.data, the data pointer is not what you want, because it still points into the full image (amplitude) and the rows keep the full image's stride. When you use show or save, OpenCV handles that indexing correctly.

https://docs.opencv.org/2.4/modules/core/doc/basic_structures.html#id6
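To make that concrete, a small sketch (image size and type are illustrative):

    // Hedged sketch: an ROI created with operator() shares the parent's buffer,
    // so its rows keep the parent's stride and the view is not contiguous.
    cv::Mat amplitude = cv::Mat::zeros(512, 512, CV_32F);
    cv::Mat roiView = amplitude(cv::Rect(100, 100, 14, 14));
    std::cout << roiView.isContinuous() << std::endl;                  // 0: rows are not packed
    std::cout << (roiView.step1() == amplitude.step1()) << std::endl;  // 1: stride of the full image

    // clone() copies the 14x14 window into its own contiguous buffer,
    // which is what the plain from_blob call above expects.
    cv::Mat patchCopy = roiView.clone();
    std::cout << patchCopy.isContinuous() << std::endl;                // 1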

If you get the strides from OpenCV, you can use the strided variant of from_blob.
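A hedged sketch of what that could look like for the ROI above (from_blob strides are given per dimension, in elements rather than bytes):

    // Hedged sketch: describe the {1, H, W, C} view of the non-contiguous ROI by
    // passing explicit strides; step1() is the parent image's row stride in elements.
    int64_t rowStride = static_cast<int64_t>(amplitudePatch.step1());
    torch::Tensor patchTensor = torch::from_blob(
        amplitudePatch.data,
        {1, amplitudePatch.rows, amplitudePatch.cols, amplitudePatch.channels()},
        {rowStride * amplitudePatch.rows, rowStride, amplitudePatch.channels(), 1},
        torch::TensorOptions().dtype(torch::kF32));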

Or you could convert the full image first and then use narrow to get the patch.
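And a sketch of the narrow variant (assuming amplitude itself is a contiguous single-channel float image):

    // Hedged sketch: wrap the whole contiguous image once, then slice the patch
    // out of the tensor; narrow(dim, start, length) returns a view with correct strides.
    torch::Tensor fullImage = torch::from_blob(
        amplitude.data,
        {1, amplitude.rows, amplitude.cols, amplitude.channels()},
        torch::TensorOptions().dtype(torch::kF32));
    // i == j == 100 here, so the row/column ordering matches the ROI above either way
    torch::Tensor patchTensor = fullImage.narrow(1, i, patchSize)   // rows
                                         .narrow(2, j, patchSize);  // columns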

Best regards

Thomas

Many thanks for the tips! ;)