Forward function slows down after several iterations

Hello everyone,

I’m implementing a network in C++ with libtorch inside a Qt application. I successfully built and trained a network, but I now have a problem when trying to deploy it.

To summarize the issue: when I call the forward function of my model in a loop, it runs fine for a number of iterations, then suddenly runs about 40x slower.

Here is an example of what I get using a simple network.

main.cpp:

#include <torch/torch.h>
#include <torch/script.h>
#include <QDebug>
#include <QElapsedTimer>


// Simple 5 linear layers network structure
struct NetImpl : torch::nn::Module {

    torch::nn::Linear linear_1;
    torch::nn::Linear linear_2;
    torch::nn::Linear linear_3;
    torch::nn::Linear linear_4;
    torch::nn::Linear linear_5;

    NetImpl(std::vector<int64_t> linear_dim_in, std::vector<int64_t> linear_dim_out)
      : linear_1(linear_dim_in[0], linear_dim_out[0]),
        linear_2(linear_dim_in[1], linear_dim_out[1]),
        linear_3(linear_dim_in[2], linear_dim_out[2]),
        linear_4(linear_dim_in[3], linear_dim_out[3]),
        linear_5(linear_dim_in[4], linear_dim_out[4])
    {
        register_module("linear_1", linear_1);
        register_module("linear_2", linear_2);
        register_module("linear_3", linear_3);
        register_module("linear_4", linear_4);
        register_module("linear_5", linear_5);
    }

    torch::Tensor forward(torch::Tensor x)
    {
        x = torch::flatten(x, 1, -1); // Flatten

        x = linear_1->forward(x);
        x = linear_2->forward(x);
        x = linear_3->forward(x);
        x = linear_4->forward(x);
        x = linear_5->forward(x);

        return x;
    }
};
TORCH_MODULE(Net); // creates module holder for NetImpl



int main()
{

    // Detect CUDA device
    torch::Device device("cpu");
    if (torch::cuda::is_available())
    {
        device = torch::Device("cuda:0");
    }

    // Declare network build variables
    std::vector<int64_t> linear_dim_in;
    linear_dim_in.push_back(72960);
    linear_dim_in.push_back(1024);
    linear_dim_in.push_back(512);
    linear_dim_in.push_back(256);
    linear_dim_in.push_back(128);

    std::vector<int64_t> linear_dim_out;
    linear_dim_out.push_back(1024);
    linear_dim_out.push_back(512);
    linear_dim_out.push_back(256);
    linear_dim_out.push_back(128);
    linear_dim_out.push_back(3);

    // Create a network
    Net net(linear_dim_in, linear_dim_out);

    // Set the model in eval mode
    net->eval();

    // Move the model to the selected device (GPU if available)
    net->to(device);


    // Declare a Timer and the desired loop number
    QElapsedTimer timer;
    int loop_nbr = 100;

    {
        // Disable gradient computation. Otherwise, autograd state accumulates in GPU memory.
        torch::NoGradGuard no_grad;

        // For loop_nbr times
        for (int i = 0; i <= loop_nbr; i++)
        {
            qDebug() << "\n" << "Loop nbr.: " << i;

            // Start timer
            timer.restart();

            // Declare a random input
            torch::Tensor X_batch = torch::rand({1,1,240,304}, device);

            // Show the elapsed time
            qDebug() << "T_tensorDeclaration: " << timer.nsecsElapsed() << "ns";

            // Run the model on the input data.
            torch::Tensor prediction = net->forward(X_batch);

            // Show the elapsed time
            qDebug() << "T_forward: " << timer.nsecsElapsed() << "ns";
        }
    }

    return 0;
}

project.pro:

QT       += core gui

greaterThan(QT_MAJOR_VERSION, 4): QT += widgets

CONFIG += c++11

CONFIG += no_keywords

# The following define makes your compiler emit warnings if you use
# any Qt feature that has been marked deprecated (the exact warnings
# depend on your compiler). Please consult the documentation of the
# deprecated API in order to know how to port your code away from it.
DEFINES += QT_DEPRECATED_WARNINGS

SOURCES += main.cpp

INCLUDEPATH += $$PWD/../../libtorch/include
DEPENDPATH += $$PWD/../../libtorch/include

INCLUDEPATH += $$PWD/../../libtorch/include/torch/csrc/api/include
DEPENDPATH += $$PWD/../../libtorch/include/torch/csrc/api/include

LIBS += -L$$PWD/../../libtorch/lib/ -ltorch -lc10

Results:

20:26:42: Debugging starts

 Loop nbr.:  0
T_tensorDeclaration:  73586 ns
T_forward:  672337 ns

 Loop nbr.:  1
T_tensorDeclaration:  16926 ns
T_forward:  106928 ns

 Loop nbr.:  2
T_tensorDeclaration:  12575 ns
T_forward:  96994 ns

 Loop nbr.:  3
T_tensorDeclaration:  12155 ns
T_forward:  95574 ns

 Loop nbr.:  4
T_tensorDeclaration:  11896 ns
T_forward:  95337 ns

 Loop nbr.:  5
T_tensorDeclaration:  11910 ns
T_forward:  95338 ns

 Loop nbr.:  6
T_tensorDeclaration:  11818 ns
T_forward:  94641 ns

 Loop nbr.:  7
T_tensorDeclaration:  11889 ns
T_forward:  95124 ns

 Loop nbr.:  8
T_tensorDeclaration:  11942 ns
T_forward:  95305 ns

 Loop nbr.:  9
T_tensorDeclaration:  11877 ns
T_forward:  94863 ns

 Loop nbr.:  10
T_tensorDeclaration:  11875 ns
T_forward:  94936 ns

 Loop nbr.:  11
T_tensorDeclaration:  11967 ns
T_forward:  94624 ns

 Loop nbr.:  12
T_tensorDeclaration:  11690 ns
T_forward:  94956 ns

 Loop nbr.:  13
T_tensorDeclaration:  11882 ns
T_forward:  94759 ns

 Loop nbr.:  14
T_tensorDeclaration:  11825 ns
T_forward:  94982 ns

 Loop nbr.:  15
T_tensorDeclaration:  11894 ns
T_forward:  94875 ns

 Loop nbr.:  16
T_tensorDeclaration:  11854 ns
T_forward:  95101 ns

 Loop nbr.:  17
T_tensorDeclaration:  11891 ns
T_forward:  94930 ns

 Loop nbr.:  18
T_tensorDeclaration:  11838 ns
T_forward:  94664 ns

 Loop nbr.:  19
T_tensorDeclaration:  11817 ns
T_forward:  94819 ns

 Loop nbr.:  20
T_tensorDeclaration:  11778 ns
T_forward:  94940 ns

 Loop nbr.:  21
T_tensorDeclaration:  11930 ns
T_forward:  95169 ns

 Loop nbr.:  22
T_tensorDeclaration:  11840 ns
T_forward:  95371 ns

 Loop nbr.:  23
T_tensorDeclaration:  11811 ns
T_forward:  94682 ns

 Loop nbr.:  24
T_tensorDeclaration:  11885 ns
T_forward:  94658 ns

 Loop nbr.:  25
T_tensorDeclaration:  11908 ns
T_forward:  94840 ns

 Loop nbr.:  26
T_tensorDeclaration:  11973 ns
T_forward:  100463 ns

 Loop nbr.:  27
T_tensorDeclaration:  13257 ns
T_forward:  98160 ns

 Loop nbr.:  28
T_tensorDeclaration:  11934 ns
T_forward:  95079 ns

 Loop nbr.:  29
T_tensorDeclaration:  11860 ns
T_forward:  95174 ns

 Loop nbr.:  30
T_tensorDeclaration:  11840 ns
T_forward:  95035 ns

 Loop nbr.:  31
T_tensorDeclaration:  11814 ns
T_forward:  96194 ns

 Loop nbr.:  32
T_tensorDeclaration:  11870 ns
T_forward:  95375 ns

 Loop nbr.:  33
T_tensorDeclaration:  11822 ns
T_forward:  95376 ns

 Loop nbr.:  34
T_tensorDeclaration:  11726 ns
T_forward:  94956 ns

 Loop nbr.:  35
T_tensorDeclaration:  11861 ns
T_forward:  95158 ns

 Loop nbr.:  36
T_tensorDeclaration:  11707 ns
T_forward:  95276 ns

 Loop nbr.:  37
T_tensorDeclaration:  11738 ns
T_forward:  95323 ns

 Loop nbr.:  38
T_tensorDeclaration:  11945 ns
T_forward:  95335 ns

 Loop nbr.:  39
T_tensorDeclaration:  11914 ns
T_forward:  95224 ns

 Loop nbr.:  40
T_tensorDeclaration:  11771 ns
T_forward:  95094 ns

 Loop nbr.:  41
T_tensorDeclaration:  11677 ns
T_forward:  95076 ns

 Loop nbr.:  42
T_tensorDeclaration:  11944 ns
T_forward:  95172 ns

 Loop nbr.:  43
T_tensorDeclaration:  11729 ns
T_forward:  95244 ns

 Loop nbr.:  44
T_tensorDeclaration:  11973 ns
T_forward:  95500 ns

 Loop nbr.:  45
T_tensorDeclaration:  11858 ns
T_forward:  95404 ns

 Loop nbr.:  46
T_tensorDeclaration:  11948 ns
T_forward:  95370 ns

 Loop nbr.:  47
T_tensorDeclaration:  11728 ns
T_forward:  95435 ns

 Loop nbr.:  48
T_tensorDeclaration:  12010 ns
T_forward:  95462 ns

 Loop nbr.:  49
T_tensorDeclaration:  11789 ns
T_forward:  95408 ns

 Loop nbr.:  50
T_tensorDeclaration:  11750 ns
T_forward:  95111 ns

 Loop nbr.:  51
T_tensorDeclaration:  11722 ns
T_forward:  95340 ns

 Loop nbr.:  52
T_tensorDeclaration:  11809 ns
T_forward:  95448 ns

 Loop nbr.:  53
T_tensorDeclaration:  11802 ns
T_forward:  95283 ns

 Loop nbr.:  54
T_tensorDeclaration:  11733 ns
T_forward:  94973 ns

 Loop nbr.:  55
T_tensorDeclaration:  11885 ns
T_forward:  95266 ns

 Loop nbr.:  56
T_tensorDeclaration:  11904 ns
T_forward:  95313 ns

 Loop nbr.:  57
T_tensorDeclaration:  11845 ns
T_forward:  95340 ns

 Loop nbr.:  58
T_tensorDeclaration:  11844 ns
T_forward:  94570 ns

 Loop nbr.:  59
T_tensorDeclaration:  11853 ns
T_forward:  95225 ns

 Loop nbr.:  60
T_tensorDeclaration:  11875 ns
T_forward:  94859 ns

 Loop nbr.:  61
T_tensorDeclaration:  11816 ns
T_forward:  95111 ns

 Loop nbr.:  62
T_tensorDeclaration:  11861 ns
T_forward:  95106 ns

 Loop nbr.:  63
T_tensorDeclaration:  12010 ns
T_forward:  96033 ns

 Loop nbr.:  64
T_tensorDeclaration:  11941 ns
T_forward:  97326 ns

 Loop nbr.:  65
T_tensorDeclaration:  11956 ns
T_forward:  95006 ns

 Loop nbr.:  66
T_tensorDeclaration:  11985 ns
T_forward:  122965 ns

 Loop nbr.:  67
T_tensorDeclaration:  25619 ns
T_forward:  109450 ns

 Loop nbr.:  68
T_tensorDeclaration:  11941 ns
T_forward:  95234 ns

 Loop nbr.:  69
T_tensorDeclaration:  11843 ns
T_forward:  95319 ns

 Loop nbr.:  70
T_tensorDeclaration:  11895 ns
T_forward:  95040 ns

 Loop nbr.:  71
T_tensorDeclaration:  11932 ns
T_forward:  95334 ns

 Loop nbr.:  72
T_tensorDeclaration:  11891 ns
T_forward:  95289 ns

 Loop nbr.:  73
T_tensorDeclaration:  11805 ns
T_forward:  95366 ns

 Loop nbr.:  74
T_tensorDeclaration:  11795 ns
T_forward:  95348 ns

 Loop nbr.:  75
T_tensorDeclaration:  11955 ns
T_forward:  95634 ns

 Loop nbr.:  76
T_tensorDeclaration:  11753 ns
T_forward:  94713 ns

 Loop nbr.:  77
T_tensorDeclaration:  11856 ns
T_forward:  95359 ns

 Loop nbr.:  78
T_tensorDeclaration:  11912 ns
T_forward:  94837 ns

 Loop nbr.:  79
T_tensorDeclaration:  11818 ns
T_forward:  95065 ns

 Loop nbr.:  80
T_tensorDeclaration:  11819 ns
T_forward:  94696 ns

 Loop nbr.:  81
T_tensorDeclaration:  11834 ns
T_forward:  95074 ns

 Loop nbr.:  82
T_tensorDeclaration:  11797 ns
T_forward:  95438 ns

 Loop nbr.:  83
T_tensorDeclaration:  11821 ns
T_forward:  95284 ns

 Loop nbr.:  84
T_tensorDeclaration:  11881 ns
T_forward:  94846 ns

 Loop nbr.:  85
T_tensorDeclaration:  11867 ns
T_forward:  95104 ns

 Loop nbr.:  86
T_tensorDeclaration:  11830 ns
T_forward:  94943 ns

 Loop nbr.:  87
T_tensorDeclaration:  11865 ns
T_forward:  95280 ns

 Loop nbr.:  88
T_tensorDeclaration:  11604 ns
T_forward:  94547 ns

 Loop nbr.:  89
T_tensorDeclaration:  11758 ns
T_forward:  94695 ns

 Loop nbr.:  90
T_tensorDeclaration:  11761 ns
T_forward:  94878 ns

 Loop nbr.:  91
T_tensorDeclaration:  11788 ns
T_forward:  95150 ns

 Loop nbr.:  92
T_tensorDeclaration:  11890 ns
T_forward:  95654 ns

 Loop nbr.:  93
T_tensorDeclaration:  11898 ns
T_forward:  94998 ns

 Loop nbr.:  94
T_tensorDeclaration:  11762 ns
T_forward:  95416 ns

 Loop nbr.:  95
T_tensorDeclaration:  11848 ns
T_forward:  3099025 ns

 Loop nbr.:  96
T_tensorDeclaration:  12879 ns
T_forward:  4235451 ns

 Loop nbr.:  97
T_tensorDeclaration:  12708 ns
T_forward:  4268075 ns

 Loop nbr.:  98
T_tensorDeclaration:  12960 ns
T_forward:  4278626 ns

 Loop nbr.:  99
T_tensorDeclaration:  12897 ns
T_forward:  4275762 ns

 Loop nbr.:  100
T_tensorDeclaration:  12794 ns
T_forward:  4086642 ns
20:26:49: Debugging has finished

...

The forward time is around 100 µs for the first 94 loops, but jumps to around 4000 µs after that. In comparison, the tensor declaration time stays roughly the same.
It looks a bit like a memory issue, but I don’t see where it comes from.
The model is in “eval” mode and “torch::NoGradGuard no_grad;” is set.

I found this recent post which looked similar, but the reply doesn’t seem to help.

Is there anything obvious that I’m missing? What could I do to solve this issue?

Here is my configuration:
OS: Kubuntu
Processor: Intel® Core™ i7-8850H CPU @ 2.60GHz
GPU: GP107GLM [Quadro P1000 Mobile]
Cuda version: 9.1
libtorch version: libtorch-cxx11-abi-shared-with-deps-1.4.0+cu92

Thanks a lot, I wish you a great day!

Best regards,
Florent

Hi,

This happens because the CUDA API is asynchronous. For the first iterations, what you measure is just the time it takes to queue the work on the GPU.
After a while, that queue is full and you have to wait for work to actually finish before more can be enqueued.
You can use the equivalent of torch.cuda.synchronize() (torch::cuda::synchronize() maybe? I’m not sure about the C++ API) to wait for all work to finish before measuring the time.
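
As a rough sketch, the timing pattern could look like the following. This assumes you can include cuda_runtime.h and link against -lcudart, so cudaDeviceSynchronize() is just one possible way to force the wait:

#include <cuda_runtime.h> // for cudaDeviceSynchronize(), requires -lcudart

// ... inside the timing loop ...
timer.restart();

// forward() only enqueues the kernels on the GPU and returns almost immediately
torch::Tensor prediction = net->forward(X_batch);
qDebug() << "T_enqueue: " << timer.nsecsElapsed() << "ns";

// Block until all queued GPU work has actually finished
cudaDeviceSynchronize();
qDebug() << "T_forward_total: " << timer.nsecsElapsed() << "ns";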

Thanks for the reply.

I found the function “cudaDeviceSynchronize();” in “cuda_runtime.h” (libs: -lcudart).

When I add it after the forward function, I get the following results:


20:46:37: Debugging starts

 Loop nbr.:  0
T_tensorDeclaration:  76694 ns
T_forward:  4744557 ns

 Loop nbr.:  1
T_tensorDeclaration:  23076 ns
T_forward:  4170827 ns

 Loop nbr.:  2
T_tensorDeclaration:  17030 ns
T_forward:  4189053 ns

 Loop nbr.:  3
T_tensorDeclaration:  15676 ns
T_forward:  4167333 ns

 Loop nbr.:  4
T_tensorDeclaration:  16317 ns
T_forward:  4172503 ns

 Loop nbr.:  5
T_tensorDeclaration:  15665 ns
T_forward:  4190116 ns

 Loop nbr.:  6
T_tensorDeclaration:  15595 ns
T_forward:  4163890 ns

 Loop nbr.:  7
T_tensorDeclaration:  15042 ns
T_forward:  4166625 ns

 Loop nbr.:  8
T_tensorDeclaration:  15177 ns
T_forward:  4180314 ns

 Loop nbr.:  9
T_tensorDeclaration:  15237 ns
T_forward:  4167773 ns

 Loop nbr.:  10
T_tensorDeclaration:  14767 ns
T_forward:  4189780 ns

 Loop nbr.:  11
T_tensorDeclaration:  15009 ns
T_forward:  4188130 ns

 Loop nbr.:  12
T_tensorDeclaration:  14792 ns
T_forward:  4169861 ns

 Loop nbr.:  13
T_tensorDeclaration:  14547 ns
T_forward:  4186025 ns

 Loop nbr.:  14
T_tensorDeclaration:  14807 ns
T_forward:  4198113 ns

 Loop nbr.:  15
T_tensorDeclaration:  14582 ns
T_forward:  4192522 ns

 Loop nbr.:  16
T_tensorDeclaration:  14601 ns
T_forward:  4183756 ns

 Loop nbr.:  17
T_tensorDeclaration:  14733 ns
T_forward:  4204688 ns

 Loop nbr.:  18
T_tensorDeclaration:  14592 ns
T_forward:  4167611 ns

 Loop nbr.:  19
T_tensorDeclaration:  14630 ns
T_forward:  4174918 ns

 Loop nbr.:  20
T_tensorDeclaration:  14659 ns
T_forward:  4156959 ns

 Loop nbr.:  21
T_tensorDeclaration:  14422 ns
T_forward:  4178084 ns

 Loop nbr.:  22
T_tensorDeclaration:  14625 ns
T_forward:  4192556 ns

 Loop nbr.:  23
T_tensorDeclaration:  14531 ns
T_forward:  4175173 ns

 Loop nbr.:  24
T_tensorDeclaration:  14580 ns
T_forward:  4166659 ns

 Loop nbr.:  25
T_tensorDeclaration:  14485 ns
T_forward:  4177613 ns

 Loop nbr.:  26
T_tensorDeclaration:  14479 ns
T_forward:  4179809 ns

 Loop nbr.:  27
T_tensorDeclaration:  14505 ns
T_forward:  4178439 ns

 Loop nbr.:  28
T_tensorDeclaration:  14496 ns
T_forward:  4184789 ns

 Loop nbr.:  29
T_tensorDeclaration:  14618 ns
T_forward:  4177106 ns

 Loop nbr.:  30
T_tensorDeclaration:  14547 ns
T_forward:  4182303 ns

 Loop nbr.:  31
T_tensorDeclaration:  14641 ns
T_forward:  4173239 ns

 Loop nbr.:  32
T_tensorDeclaration:  19249 ns
T_forward:  4176845 ns

 Loop nbr.:  33
T_tensorDeclaration:  14245 ns
T_forward:  4128455 ns

 Loop nbr.:  34
T_tensorDeclaration:  14577 ns
T_forward:  4151471 ns

 Loop nbr.:  35
T_tensorDeclaration:  14191 ns
T_forward:  4157208 ns

 Loop nbr.:  36
T_tensorDeclaration:  14156 ns
T_forward:  4157328 ns

 Loop nbr.:  37
T_tensorDeclaration:  14135 ns
T_forward:  4164124 ns

 Loop nbr.:  38
T_tensorDeclaration:  14024 ns
T_forward:  4149202 ns

 Loop nbr.:  39
T_tensorDeclaration:  14068 ns
T_forward:  4148399 ns

 Loop nbr.:  40
T_tensorDeclaration:  14095 ns
T_forward:  4149947 ns

 Loop nbr.:  41
T_tensorDeclaration:  14166 ns
T_forward:  4134731 ns

 Loop nbr.:  42
T_tensorDeclaration:  13995 ns
T_forward:  4130745 ns

 Loop nbr.:  43
T_tensorDeclaration:  14047 ns
T_forward:  4157489 ns

 Loop nbr.:  44
T_tensorDeclaration:  17812 ns
T_forward:  4145059 ns

 Loop nbr.:  45
T_tensorDeclaration:  14840 ns
T_forward:  4145376 ns

 Loop nbr.:  46
T_tensorDeclaration:  14524 ns
T_forward:  4144078 ns

 Loop nbr.:  47
T_tensorDeclaration:  14292 ns
T_forward:  4156597 ns

 Loop nbr.:  48
T_tensorDeclaration:  14604 ns
T_forward:  4130630 ns

 Loop nbr.:  49
T_tensorDeclaration:  14345 ns
T_forward:  4156479 ns

 Loop nbr.:  50
T_tensorDeclaration:  14376 ns
T_forward:  4137532 ns

 Loop nbr.:  51
T_tensorDeclaration:  14162 ns
T_forward:  4138844 ns

 Loop nbr.:  52
T_tensorDeclaration:  14233 ns
T_forward:  4138394 ns

 Loop nbr.:  53
T_tensorDeclaration:  14245 ns
T_forward:  4160527 ns

 Loop nbr.:  54
T_tensorDeclaration:  16532 ns
T_forward:  4154455 ns

 Loop nbr.:  55
T_tensorDeclaration:  15135 ns
T_forward:  4151513 ns

 Loop nbr.:  56
T_tensorDeclaration:  14605 ns
T_forward:  4159401 ns

 Loop nbr.:  57
T_tensorDeclaration:  15114 ns
T_forward:  4146790 ns

 Loop nbr.:  58
T_tensorDeclaration:  15189 ns
T_forward:  4163448 ns

 Loop nbr.:  59
T_tensorDeclaration:  14600 ns
T_forward:  4154543 ns

 Loop nbr.:  60
T_tensorDeclaration:  14391 ns
T_forward:  4152786 ns

 Loop nbr.:  61
T_tensorDeclaration:  14449 ns
T_forward:  4139992 ns

 Loop nbr.:  62
T_tensorDeclaration:  14358 ns
T_forward:  4136887 ns

 Loop nbr.:  63
T_tensorDeclaration:  14809 ns
T_forward:  4138758 ns

 Loop nbr.:  64
T_tensorDeclaration:  14358 ns
T_forward:  4128357 ns

 Loop nbr.:  65
T_tensorDeclaration:  14318 ns
T_forward:  4138889 ns

 Loop nbr.:  66
T_tensorDeclaration:  14135 ns
T_forward:  4138168 ns

 Loop nbr.:  67
T_tensorDeclaration:  14288 ns
T_forward:  4158496 ns

 Loop nbr.:  68
T_tensorDeclaration:  14304 ns
T_forward:  4144945 ns

 Loop nbr.:  69
T_tensorDeclaration:  14286 ns
T_forward:  4137468 ns

 Loop nbr.:  70
T_tensorDeclaration:  14182 ns
T_forward:  4163196 ns

 Loop nbr.:  71
T_tensorDeclaration:  15629 ns
T_forward:  4154694 ns

 Loop nbr.:  72
T_tensorDeclaration:  14497 ns
T_forward:  4119437 ns

 Loop nbr.:  73
T_tensorDeclaration:  14230 ns
T_forward:  4150738 ns

 Loop nbr.:  74
T_tensorDeclaration:  14308 ns
T_forward:  4382309 ns

 Loop nbr.:  75
T_tensorDeclaration:  15805 ns
T_forward:  4121531 ns

 Loop nbr.:  76
T_tensorDeclaration:  15074 ns
T_forward:  4116742 ns

 Loop nbr.:  77
T_tensorDeclaration:  14275 ns
T_forward:  4126637 ns

 Loop nbr.:  78
T_tensorDeclaration:  14278 ns
T_forward:  4127533 ns

 Loop nbr.:  79
T_tensorDeclaration:  14324 ns
T_forward:  4128542 ns

 Loop nbr.:  80
T_tensorDeclaration:  14327 ns
T_forward:  4118499 ns

 Loop nbr.:  81
T_tensorDeclaration:  14270 ns
T_forward:  4138547 ns

 Loop nbr.:  82
T_tensorDeclaration:  14196 ns
T_forward:  4130813 ns

 Loop nbr.:  83
T_tensorDeclaration:  18730 ns
T_forward:  4100290 ns

 Loop nbr.:  84
T_tensorDeclaration:  14288 ns
T_forward:  4127821 ns

 Loop nbr.:  85
T_tensorDeclaration:  14240 ns
T_forward:  4113148 ns

 Loop nbr.:  86
T_tensorDeclaration:  14323 ns
T_forward:  4118545 ns

 Loop nbr.:  87
T_tensorDeclaration:  14405 ns
T_forward:  4117402 ns

 Loop nbr.:  88
T_tensorDeclaration:  14367 ns
T_forward:  4108369 ns

 Loop nbr.:  89
T_tensorDeclaration:  14503 ns
T_forward:  4139753 ns

 Loop nbr.:  90
T_tensorDeclaration:  14343 ns
T_forward:  4128485 ns

 Loop nbr.:  91
T_tensorDeclaration:  14376 ns
T_forward:  4128973 ns

 Loop nbr.:  92
T_tensorDeclaration:  14353 ns
T_forward:  4126810 ns

 Loop nbr.:  93
T_tensorDeclaration:  14201 ns
T_forward:  4123998 ns

 Loop nbr.:  94
T_tensorDeclaration:  15228 ns
T_forward:  4122794 ns

 Loop nbr.:  95
T_tensorDeclaration:  14555 ns
T_forward:  4106747 ns

 Loop nbr.:  96
T_tensorDeclaration:  14585 ns
T_forward:  4130666 ns

 Loop nbr.:  97
T_tensorDeclaration:  14393 ns
T_forward:  4126877 ns

 Loop nbr.:  98
T_tensorDeclaration:  14084 ns
T_forward:  4134650 ns

 Loop nbr.:  99
T_tensorDeclaration:  14464 ns
T_forward:  4138943 ns

 Loop nbr.:  100
T_tensorDeclaration:  14557 ns
T_forward:  4120151 ns
20:46:44: Debugging has finished

Does it mean that the real forward time of the network is around 4 ms? That seems quite slow for such a small network, doesn’t it?
When the data are fed in batches during training, it goes much faster.
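
For comparison, here is a rough sketch of how a batched forward pass could be timed with the same setup (the batch size of 64 is an arbitrary choice, just for illustration):

// Inside the no_grad block, same Net and device as above
torch::Tensor X_batch = torch::rand({64,1,240,304}, device); // batch of 64 samples

timer.restart();
torch::Tensor prediction = net->forward(X_batch);
cudaDeviceSynchronize(); // wait for the GPU to finish before reading the timer

qint64 t = timer.nsecsElapsed();
qDebug() << "T_forward (batch of 64): " << t << "ns"
         << " => per sample: " << t / 64 << "ns";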

Indeed, after more tests, it seems to be the real computation time.

So to summarize: the timer measures the time it takes to go through the model’s forward function, which only enqueues the model’s computation tasks in the GPU’s queue. When that queue is full, the program waits for space in the queue, so the timer then also measures the time needed for the queue to advance (i.e. for the queued computation tasks to actually complete). That measured time is much closer to the real computation time.

I’m posting the final code example in case it could help someone. I didn’t find many examples for C++ and Qt, so it might be a good starting point for others.

main.cpp:

#include <torch/torch.h>
#include <torch/script.h>
#include <QDebug>
#include <QElapsedTimer>

#include <cuda_runtime.h>

// Simple 5 linear layers network
struct NetImpl : torch::nn::Module {

    torch::nn::Linear linear_1;
    torch::nn::Linear linear_2;
    torch::nn::Linear linear_3;
    torch::nn::Linear linear_4;
    torch::nn::Linear linear_5;

    NetImpl(std::vector<int64_t> linear_dim_in, std::vector<int64_t> linear_dim_out)
      : linear_1(linear_dim_in[0], linear_dim_out[0]),
        linear_2(linear_dim_in[1], linear_dim_out[1]),
        linear_3(linear_dim_in[2], linear_dim_out[2]),
        linear_4(linear_dim_in[3], linear_dim_out[3]),
        linear_5(linear_dim_in[4], linear_dim_out[4])
    {
        register_module("linear_1", linear_1);
        register_module("linear_2", linear_2);
        register_module("linear_3", linear_3);
        register_module("linear_4", linear_4);
        register_module("linear_5", linear_5);
    }

    torch::Tensor forward(torch::Tensor x)
    {
        x = torch::flatten(x, 1, -1); // Flatten

        x = linear_1->forward(x);
        x = linear_2->forward(x);
        x = linear_3->forward(x);
        x = linear_4->forward(x);
        x = linear_5->forward(x);

        return x;
    }
};
TORCH_MODULE(Net); // creates module holder for NetImpl



int main()
{

    // Detect CUDA device
    torch::Device device("cpu");
    if (torch::cuda::is_available())
    {
        device = torch::Device("cuda:0");
    }

    // Declare network build variables
    std::vector<int64_t> linear_dim_in;
    linear_dim_in.push_back(72960);
    linear_dim_in.push_back(1024);
    linear_dim_in.push_back(512);
    linear_dim_in.push_back(256);
    linear_dim_in.push_back(128);

    std::vector<int64_t> linear_dim_out;
    linear_dim_out.push_back(1024);
    linear_dim_out.push_back(512);
    linear_dim_out.push_back(256);
    linear_dim_out.push_back(128);
    linear_dim_out.push_back(3);

    // Create network
    Net net(linear_dim_in, linear_dim_out);

    // Set the model in eval mode
    net->eval();

    // Move the model to the selected device (GPU if available)
    net->to(device);


    // Declare a Timer and the loop number
    QElapsedTimer timer;
    int loop_nbr = 100;

    {
        // Disable gradient computation. Otherwise, autograd state accumulates in GPU memory.
        torch::NoGradGuard no_grad;

        // For loop_nbr times
        for (int i = 0; i <= loop_nbr; i++)
        {
            qDebug() << "\n" << "Loop nbr.: " << i;

            // Start timer
            timer.restart();

            // Declare a random input
            torch::Tensor X_batch = torch::rand({1,1,240,304}, device);

            // Show the elapsed time
            qDebug() << "T_tensorDeclaration: " << timer.nsecsElapsed() << "ns";

            // Run the model on the input data.
            torch::Tensor prediction = net->forward(X_batch);

            // Wait for the end of GPU computation
            cudaDeviceSynchronize();

            // Show the elapsed time
            qDebug() << "T_forward: " << timer.nsecsElapsed() << "ns";
        }
    }

    return 0;
}

project.pro:

QT       += core gui

greaterThan(QT_MAJOR_VERSION, 4): QT += widgets

CONFIG += c++11

CONFIG += no_keywords

# The following define makes your compiler emit warnings if you use
# any Qt feature that has been marked deprecated (the exact warnings
# depend on your compiler). Please consult the documentation of the
# deprecated API in order to know how to port your code away from it.
DEFINES += QT_DEPRECATED_WARNINGS

SOURCES += main.cpp

INCLUDEPATH += $$PWD/../../libtorch/include
DEPENDPATH += $$PWD/../../libtorch/include

INCLUDEPATH += $$PWD/../../libtorch/include/torch/csrc/api/include
DEPENDPATH += $$PWD/../../libtorch/include/torch/csrc/api/include

LIBS += -L$$PWD/../../libtorch/lib/ -ltorch -lc10 -lcudart