Weird time measurement outputs and the correct way to do C++ inference

Hello. I wrote prediction code for a classifier in C++. I am not experienced with C++ and I am new to libtorch. When the model prediction print line exists, the timing output is as below:

Timing output when the model prediction print line exists:

Time passed outside of function in seconds: 0.052
Time passed inside of function in seconds: 0.006

When the model prediction print line does not exist, the timing output is as below:

Timing output when the model prediction print line does not exist:

Time passed inside of function in seconds: 0.051
Time passed outside of function in seconds: 0.051

The model prediction print line:

std::cout << "Prediction is: " << prediction << std::endl;

My full code:

#include <iostream>
#include <torch/torch.h> 
#include <torch/script.h>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc.hpp>
#include <zmq.hpp> 
#include "metolib.hpp"
#include <chrono>
#include <unistd.h> // timelib



torch::Tensor classifier_predict_image(cv::Mat image, torch::jit::script::Module model)
{
    auto func_t0 = std::chrono::high_resolution_clock::now();

    torch::Tensor tensor_image = metolib::mat2tensor(image, image.size().height, image.size().width, image.channels()); 
    std::vector<torch::jit::IValue> input;
    input.push_back(tensor_image.to(torch::kCUDA));
    torch::Tensor output = model.forward(input).toTensor();
    output = torch::argmax(output);

    auto func_t1 = std::chrono::high_resolution_clock::now();
    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(func_t1 - func_t0);
    std::cout << "Time passed inside of function in seconds: " << (float)duration.count() / 1000 << std::endl;

    return output;
} 


int main(int argc, const char *argv[])
{     
    // Inits
    torch::init_num_threads();
    cv::Mat frame;
    torch::Tensor prediction;

    // Load model
    torch::jit::script::Module model = torch::jit::load(argv[1]);
    model.eval(); 
    model.to(torch::kCUDA);

    // Video input
    cv::VideoCapture cap(0);

    while(cap.isOpened())
    {
        // read frame
        cap >> frame;

        // predict
        auto main_loop_t0 = std::chrono::high_resolution_clock::now();
        
        prediction = classifier_predict_image(frame, model);
        std::cout << "Prediction is: " << prediction << std::endl;

        auto main_loop_t1 = std::chrono::high_resolution_clock::now();
        auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(main_loop_t1 - main_loop_t0);
        std::cout << "Time passed outside of function in seconds: " << (float)duration.count() / 1000 << std::endl;

        cv::imshow("Output", frame);
        char key = cv::waitKey(1);
        if((int)key == 'q') {std::cout << "Q pressed. Exiting..." << std::endl; cv::destroyAllWindows(); break;}
    }

}

The mat2tensor function:

torch::Tensor mat2tensor(cv::Mat image, int h, int w, int channels)
{
    torch::Tensor out_tensor = torch::from_blob(image.data, {1, h, w, channels}).permute({0, 3, 1, 2});
    return out_tensor;
}
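
Note that torch::from_blob uses the default dtype (kFloat) when no options are passed, while frames from cv::VideoCapture are normally CV_8UC3, i.e. 8-bit unsigned. If that is the case here, the dtype would have to be stated explicitly and converted afterwards; a minimal sketch:

torch::Tensor mat2tensor(cv::Mat image, int h, int w, int channels)
{
    // Interpret the OpenCV buffer as uint8 (from_blob defaults to kFloat),
    // reorder NHWC -> NCHW, then convert to float for the model.
    // from_blob does not copy: image must stay alive while the tensor is used.
    torch::Tensor out_tensor = torch::from_blob(image.data, {1, h, w, channels}, torch::kUInt8)
                                   .permute({0, 3, 1, 2})
                                   .to(torch::kFloat32);
    return out_tensor;
}

Without the explicit dtype, the float tensor would simply reinterpret the raw uint8 bytes.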

I followed the tutorial below for exporting a Python-trained model to C++.
https://pytorch.org/tutorials/advanced/cpp_export.html

My CMakeLists.txt:

cmake_minimum_required (VERSION 2.8)
project(example_project)
set( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14" )
set( OpenCV_DIR "/usr/share/OpenCV" )
set( PKG_CONFIG_PATH "/usr/lib/pkgconfig" )
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED) 


add_executable(../bin/app ../sources/app.cpp)
include_directories(/home/m3/meto_ws/c_cpp_libs/libtorch/include/torch/csrc/api/include)
include_directories(../libs)
target_link_libraries(../bin/app "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")

set_property(TARGET ../bin/app PROPERTY CXX_STANDARD 14)

I downloaded libtorch from pytorch.org via the link below.
https://download.pytorch.org/libtorch/cu111/libtorch-cxx11-abi-shared-with-deps-1.8.1%2Bcu111.zip

My CUDA version is 11.1.

My questions:
1 - Any idea why the timing outputs differ so much?
2 - Is my CMakeLists.txt correct?
3 - Is my code correct?

CUDA operations are executed asynchronously, so you would need to synchronize the code before starting and stopping the timers to get a valid profile of the runtime. The std::cout operation is most likely a synchronization point, since the result has to be copied back to the host before it can be printed to the terminal.

Can you show me samples or a source to read about synchronization? I have no idea how to do it.

You can synchronize your code in libtorch via:

// requires #include <c10/cuda/CUDAStream.h> and <c10/cuda/CUDAException.h>;
// `device` here is the torch::Device your model runs on
auto stream = c10::cuda::getCurrentCUDAStream(device.index());
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
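
For example, your classifier_predict_image could be adapted roughly as below. This is only a sketch: cuda_sync is a helper name I made up, and it assumes everything runs on the current CUDA device (the no-argument overload of getCurrentCUDAStream then picks that device):

#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAException.h>

// Hypothetical helper: block until all work queued on the current
// CUDA stream has finished executing on the GPU.
static void cuda_sync()
{
    auto stream = c10::cuda::getCurrentCUDAStream();
    C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}

torch::Tensor classifier_predict_image(cv::Mat image, torch::jit::script::Module model)
{
    cuda_sync();  // don't let earlier queued GPU work leak into the measurement
    auto func_t0 = std::chrono::high_resolution_clock::now();

    torch::Tensor tensor_image = metolib::mat2tensor(image, image.size().height, image.size().width, image.channels());
    std::vector<torch::jit::IValue> input;
    input.push_back(tensor_image.to(torch::kCUDA));
    torch::Tensor output = model.forward(input).toTensor();
    output = torch::argmax(output);

    cuda_sync();  // wait until the forward pass has actually finished on the GPU
    auto func_t1 = std::chrono::high_resolution_clock::now();

    auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(func_t1 - func_t0);
    std::cout << "Time passed inside of function in seconds: " << duration.count() / 1000.0f << std::endl;

    return output;
}

The same pattern applies to the timers in main: synchronize right before each call to high_resolution_clock::now(), otherwise you only measure how long it takes to enqueue the kernels, not how long they take to run.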

I added the two lines to my code, but there is no change.