What caused the large difference in execution time between these two forward passes?

for i in range(263):
    states = torch.rand(128, 4, 20, 20, device=device)
    # time.sleep(1)
    start_time = time.time()
    policy, value, action_value = model(states)
    end_time = time.time()
    total_time = end_time - start_time
    print(f"time cost:{total_time:.4f} second")

It prints:
time cost:0.0003 second
time cost:0.0003 second
time cost:0.0003 second

But when I uncomment time.sleep(1), it becomes slower:

time cost:0.0008 second
time cost:0.0009 second
time cost:0.0009 second

In my C++ program using LibTorch (200+ lines, so I won't post all of it here), the difference is even larger:

std::shared_ptr<c10::IValue> results_tensor = std::make_shared<c10::IValue>();
std::shared_ptr<torch::jit::script::Module> model;
while(true)
{
    auto start = std::chrono::high_resolution_clock::now();

    *results_tensor = nn.model->forward({input});

    auto end = std::chrono::high_resolution_clock::now();
    std::chrono::duration<double> duration = end - start;
    {
        std::lock_guard<std::mutex> lock(cout_mutex);
        std::cout << "Inference time for forward: " << duration.count() << " seconds." << std::endl;
    }
}

It takes only 0.0001 seconds per call.
Without the while(true) loop, when the program runs normally, it takes 0.01 seconds. That's 100x slower!

Assuming you are using a GPU, your profiling is invalid, since CUDA operations are executed asynchronously. You would need to synchronize the code before starting and stopping the host timers.
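A minimal sketch of that timing pattern, in case it helps. The helper name timed_call and its sync, warmup and iters parameters are made up for illustration; the only idea taken from the advice above is to synchronize right before the host timer starts and right before it stops. The demo call uses a CPU-only stand-in so the snippet runs without a GPU.

```python
import time

def timed_call(fn, *args, sync=None, warmup=3, iters=10):
    """Return the mean wall-clock seconds per call of fn(*args).

    sync (hypothetical parameter) is called right before the timer starts
    and right before it stops, so that asynchronously queued work has
    finished at both timestamps.
    """
    for _ in range(warmup):      # untimed warm-up calls
        fn(*args)
    if sync is not None:
        sync()                   # drain pending work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if sync is not None:
        sync()                   # wait for the timed work before stopping the clock
    return (time.perf_counter() - start) / iters


# CPU-only stand-in so the sketch runs anywhere; no synchronization needed here.
mean_s = timed_call(sum, range(100_000))
print(f"mean time per call: {mean_s:.6f} s")
```

With PyTorch on a GPU, the call would presumably look like `timed_call(model, states, sync=torch.cuda.synchronize)`. Note that the warm-up and averaging hide the cost of the first call after an idle period, so to measure that single call you would pass `warmup=0, iters=1`.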

I wrote a minimal reproducible example:

import torch
import torch.nn as nn
import time

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.shared_layers = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )
        self.policy_head = nn.Linear(64, 400)
        self.value_head = nn.Linear(64, 1)
        self.action_value_head = nn.Linear(64, 400)

    def forward(self, x):
        x = self.shared_layers(x)
        x = x.view(x.size(0), -1)
        policy = self.policy_head(x)
        value = self.value_head(x)
        action_values = self.action_value_head(x)
        return policy, value, action_values

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using device: ", device)
    model = Net().to(device)
    model.eval()

    print("\nnow no sleep:")
    for i in range(10):
        states = torch.rand(128, 4, 20, 20, device = device)
        #no time.sleep
        start_time = time.time()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()

        # forward
        policy, value, action_value = model(states)

        end_event.record()
        torch.cuda.synchronize()
        end_time = time.time()
        total_time = start_event.elapsed_time(end_event) / 1000  # convert ms to seconds
        print(f"time cost counted by gpu: {total_time:.4f} second")
        total_time = end_time - start_time
        print(f"time cost counted by cpu:{total_time:.4f} second")


    print("\nnow sleep 1:")
    for i in range(10):
        states = torch.rand(128, 4, 20, 20, device = device)
        time.sleep(1)
        start_time = time.time()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()

        # forward
        policy, value, action_value = model(states)

        end_event.record()
        torch.cuda.synchronize()
        end_time = time.time()
        total_time = start_event.elapsed_time(end_event) / 1000  # convert ms to seconds
        print(f"time cost counted by gpu: {total_time:.4f} second")
        total_time = end_time - start_time
        print(f"time cost counted by cpu:{total_time:.4f} second")



    print("\nnow sleep 10:")
    for i in range(10):
        states = torch.rand(128, 4, 20, 20, device = device)
        time.sleep(10)
        start_time = time.time()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()

        # forward
        policy, value, action_value = model(states)

        end_event.record()
        torch.cuda.synchronize()
        end_time = time.time()
        total_time = start_event.elapsed_time(end_event) / 1000  # convert ms to seconds
        print(f"time cost counted by gpu: {total_time:.4f} second")
        total_time = end_time - start_time
        print(f"time cost counted by cpu:{total_time:.4f} second")



    print("\nnow sleep 60:")
    for i in range(10):
        states = torch.rand(128, 4, 20, 20, device = device)
        time.sleep(60)
        start_time = time.time()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()

        # forward
        policy, value, action_value = model(states)

        end_event.record()
        torch.cuda.synchronize()
        end_time = time.time()
        total_time = start_event.elapsed_time(end_event) / 1000  # convert ms to seconds
        print(f"time cost counted by gpu: {total_time:.4f} second")
        total_time = end_time - start_time
        print(f"time cost counted by cpu:{total_time:.4f} second")

The result of running this minimal.py on my computer is:

using device:  cuda

now no sleep:
time cost counted by gpu: 0.2766 second
time cost counted by cpu:0.2768 second
time cost counted by gpu: 0.0012 second
time cost counted by cpu:0.0013 second
time cost counted by gpu: 0.0004 second
time cost counted by cpu:0.0004 second
time cost counted by gpu: 0.0004 second
time cost counted by cpu:0.0004 second
time cost counted by gpu: 0.0003 second
time cost counted by cpu:0.0004 second
time cost counted by gpu: 0.0003 second
time cost counted by cpu:0.0004 second
time cost counted by gpu: 0.0004 second
time cost counted by cpu:0.0004 second
time cost counted by gpu: 0.0003 second
time cost counted by cpu:0.0004 second
time cost counted by gpu: 0.0004 second
time cost counted by cpu:0.0004 second
time cost counted by gpu: 0.0003 second
time cost counted by cpu:0.0004 second

now sleep 1:
time cost counted by gpu: 0.0005 second
time cost counted by cpu:0.0007 second
time cost counted by gpu: 0.0006 second
time cost counted by cpu:0.0007 second
time cost counted by gpu: 0.0006 second
time cost counted by cpu:0.0007 second
time cost counted by gpu: 0.0006 second
time cost counted by cpu:0.0007 second
time cost counted by gpu: 0.0008 second
time cost counted by cpu:0.0009 second
time cost counted by gpu: 0.0007 second
time cost counted by cpu:0.0008 second
time cost counted by gpu: 0.0007 second
time cost counted by cpu:0.0007 second
time cost counted by gpu: 0.0006 second
time cost counted by cpu:0.0008 second
time cost counted by gpu: 0.0008 second
time cost counted by cpu:0.0008 second
time cost counted by gpu: 0.0007 second
time cost counted by cpu:0.0008 second

now sleep 10:
time cost counted by gpu: 0.0007 second
time cost counted by cpu:0.0009 second
time cost counted by gpu: 0.0009 second
time cost counted by cpu:0.0011 second
time cost counted by gpu: 0.0009 second
time cost counted by cpu:0.0011 second
time cost counted by gpu: 0.0010 second
time cost counted by cpu:0.0014 second
time cost counted by gpu: 0.0009 second
time cost counted by cpu:0.0011 second
time cost counted by gpu: 0.0008 second
time cost counted by cpu:0.0010 second
time cost counted by gpu: 0.0016 second
time cost counted by cpu:0.0019 second
time cost counted by gpu: 0.0009 second
time cost counted by cpu:0.0011 second
time cost counted by gpu: 0.0007 second
time cost counted by cpu:0.0010 second
time cost counted by gpu: 0.0008 second
time cost counted by cpu:0.0010 second

now sleep 60:
time cost counted by gpu: 0.0014 second
time cost counted by cpu:0.0101 second
time cost counted by gpu: 0.0013 second
time cost counted by cpu:0.0142 second
time cost counted by gpu: 0.0015 second
time cost counted by cpu:0.0120 second
time cost counted by gpu: 0.0010 second
time cost counted by cpu:0.0114 second
time cost counted by gpu: 0.0013 second
time cost counted by cpu:0.0152 second
time cost counted by gpu: 0.0014 second
time cost counted by cpu:0.0117 second
time cost counted by gpu: 0.0015 second
time cost counted by cpu:0.0119 second
time cost counted by gpu: 0.0011 second
time cost counted by cpu:0.0154 second
time cost counted by gpu: 0.0035 second
time cost counted by cpu:0.0264 second
time cost counted by gpu: 0.0011 second
time cost counted by cpu:0.0115 second

The results show that the longer the time.sleep before the forward pass, the longer the forward pass takes.
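To make the trend easier to read, the printed log can be averaged per section. This is only a rough sketch: the LOG string is an excerpt (the first two iterations of each sleep section from the output above), and the summarize helper is a made-up name for illustration.

```python
import re
from collections import defaultdict
from statistics import mean

# Excerpt of the output above: first two iterations of each sleep section.
LOG = """\
now sleep 1:
time cost counted by gpu: 0.0005 second
time cost counted by cpu:0.0007 second
time cost counted by gpu: 0.0006 second
time cost counted by cpu:0.0007 second

now sleep 10:
time cost counted by gpu: 0.0007 second
time cost counted by cpu:0.0009 second
time cost counted by gpu: 0.0009 second
time cost counted by cpu:0.0011 second

now sleep 60:
time cost counted by gpu: 0.0014 second
time cost counted by cpu:0.0101 second
time cost counted by gpu: 0.0013 second
time cost counted by cpu:0.0142 second
"""

HEADER = re.compile(r"^now (.+):$")
SAMPLE = re.compile(r"^time cost counted by (gpu|cpu):\s*([0-9.]+) second$")

def summarize(log):
    """Return {section: {"gpu": mean_seconds, "cpu": mean_seconds}}."""
    samples = defaultdict(lambda: defaultdict(list))
    section = None
    for line in log.splitlines():
        line = line.strip()
        if m := HEADER.match(line):
            section = m.group(1)          # e.g. "sleep 60"
        elif (m := SAMPLE.match(line)) and section is not None:
            samples[section][m.group(1)].append(float(m.group(2)))
    return {
        name: {dev: mean(vals) for dev, vals in devs.items()}
        for name, devs in samples.items()
    }

summary = summarize(LOG)
for name, stats in summary.items():
    gap = stats["cpu"] - stats["gpu"]     # time not accounted for by the GPU events
    print(f"{name:>8}: gpu {stats['gpu'] * 1e3:.2f} ms, "
          f"cpu {stats['cpu'] * 1e3:.2f} ms, host-side gap {gap * 1e3:.2f} ms")
```

In this excerpt the GPU-timed figure grows only modestly with the sleep length, while the CPU-timed figure for the 60-second sleep is roughly ten times larger, so most of the extra latency appears to sit on the host side of the call rather than in the kernels themselves.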

Minimal reproducible C++ example:

#include <torch/torch.h>
#include <cuda_runtime.h>
#include <chrono>
#include <iostream>
#include <thread>
#include <tuple>

struct Net : public torch::nn::Module {
    // Define the network structure
    torch::nn::Sequential shared_layers{
        torch::nn::Conv2d(torch::nn::Conv2dOptions(4, 32, 3).padding(1)),
        torch::nn::ReLU(),
        torch::nn::Conv2d(torch::nn::Conv2dOptions(32, 64, 3).padding(1)),
        torch::nn::ReLU(),
        torch::nn::AdaptiveAvgPool2d(1)
    };

    torch::nn::Linear policy_head{64, 400};
    torch::nn::Linear value_head{64, 1};
    torch::nn::Linear action_value_head{64, 400};

    // Constructor: Initialize layers
    Net() {
        // Register the layers (registering is mandatory)
        register_module("shared_layers", shared_layers);
        register_module("policy_head", policy_head);
        register_module("value_head", value_head);
        register_module("action_value_head", action_value_head);
    }

    // Forward pass function
    std::tuple<at::Tensor, at::Tensor, at::Tensor> forward(at::Tensor x) {
        x = shared_layers->forward(x);  // Pass through shared layers
        x = x.view({x.size(0), -1});     // Flatten output

        at::Tensor policy = policy_head->forward(x);
        at::Tensor value = value_head->forward(x);
        at::Tensor action_values = action_value_head->forward(x);

        return std::make_tuple(policy, value, action_values);
    }
};

int main() {
    // Initialize the network
    Net net;
    
    net.to(torch::kCUDA);

    // Create a random input tensor (batch size 128, 4 channels, 20x20)
    at::Tensor input_tensor = torch::rand({128, 4, 20, 20}, torch::kCUDA);

    // Run the input through the network twice per iteration and time each forward pass
    for(int i = 0; i < 10; i++)
    {
        std::this_thread::sleep_for(std::chrono::seconds(60)); // pretend some calculations are being performed
        // forward pass I
        {
            auto cpu_start = std::chrono::high_resolution_clock::now();

            cudaEvent_t start, stop;
            cudaEventCreate(&start);
            cudaEventCreate(&stop);

            cudaEventRecord(start, 0);  // Start timing

            auto output = net.forward(input_tensor);

            cudaEventRecord(stop, 0);  // End timing

            cudaEventSynchronize(stop);  // Wait for GPU to finish all operations
            float elapsedTime;
            cudaEventElapsedTime(&elapsedTime, start, stop);  // Calculate time

            std::cout << "Inference time I for forward counted by GPU: " << elapsedTime/1000.0 << " seconds." << std::endl;

            auto cpu_end = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> duration = cpu_end - cpu_start;

            std::cout << "Inference time I for forward counted by CPU: " << duration.count() << " seconds." << std::endl;
        }
        // forward pass II
        {
            auto cpu_start = std::chrono::high_resolution_clock::now();

            cudaEvent_t start, stop;
            cudaEventCreate(&start);
            cudaEventCreate(&stop);

            cudaEventRecord(start, 0);  // Start timing

            auto output = net.forward(input_tensor);

            cudaEventRecord(stop, 0);  // End timing

            cudaEventSynchronize(stop);  // Wait for GPU to finish all operations
            float elapsedTime;
            cudaEventElapsedTime(&elapsedTime, start, stop);  // Calculate time

            std::cout << "Inference time II for forward counted by GPU: " << elapsedTime/1000.0 << " seconds." << std::endl;

            auto cpu_end = std::chrono::high_resolution_clock::now();
            std::chrono::duration<double> duration = cpu_end - cpu_start;

            std::cout << "Inference time II for forward counted by CPU: " << duration.count() << " seconds." << std::endl << std::endl;
        }
    }
    
    return 0;
}

The result is:

Inference time I for forward counted by GPU: 0.275764 seconds.
Inference time I for forward counted by CPU: 0.286466 seconds.
Inference time II for forward counted by GPU: 0.000920576 seconds.
Inference time II for forward counted by CPU: 0.000943526 seconds.

Inference time I for forward counted by GPU: 0.000538624 seconds.
Inference time I for forward counted by CPU: 0.0112382 seconds.
Inference time II for forward counted by GPU: 0.00026624 seconds.
Inference time II for forward counted by CPU: 0.000288655 seconds.

Inference time I for forward counted by GPU: 0.00054272 seconds.
Inference time I for forward counted by CPU: 0.0106249 seconds.
Inference time II for forward counted by GPU: 0.000239616 seconds.
Inference time II for forward counted by CPU: 0.000260844 seconds.

Inference time I for forward counted by GPU: 0.000884736 seconds.
Inference time I for forward counted by CPU: 0.0111457 seconds.
Inference time II for forward counted by GPU: 0.000375808 seconds.
Inference time II for forward counted by CPU: 0.000403737 seconds.

Inference time I for forward counted by GPU: 0.000914432 seconds.
Inference time I for forward counted by CPU: 0.0109758 seconds.
Inference time II for forward counted by GPU: 0.000355328 seconds.
Inference time II for forward counted by CPU: 0.000394496 seconds.

Inference time I for forward counted by GPU: 0.0132884 seconds.
Inference time I for forward counted by CPU: 0.0241791 seconds.
Inference time II for forward counted by GPU: 0.00087552 seconds.
Inference time II for forward counted by CPU: 0.000916996 seconds.

Inference time I for forward counted by GPU: 0.00104243 seconds.
Inference time I for forward counted by CPU: 0.0144651 seconds.
Inference time II for forward counted by GPU: 0.000500736 seconds.
Inference time II for forward counted by CPU: 0.000539569 seconds.

Inference time I for forward counted by GPU: 0.0010967 seconds.
Inference time I for forward counted by CPU: 0.0111124 seconds.
Inference time II for forward counted by GPU: 0.000514048 seconds.
Inference time II for forward counted by CPU: 0.000548997 seconds.

Inference time I for forward counted by GPU: 0.0008704 seconds.
Inference time I for forward counted by CPU: 0.0117548 seconds.
Inference time II for forward counted by GPU: 0.000499712 seconds.
Inference time II for forward counted by CPU: 0.000536338 seconds.

Inference time I for forward counted by GPU: 0.000856064 seconds.
Inference time I for forward counted by CPU: 0.0109821 seconds.
Inference time II for forward counted by GPU: 0.000478208 seconds.
Inference time II for forward counted by CPU: 0.00051494 seconds.

The second forward pass is always much faster than the first. It seems as if some cache gets cleared even though I do nothing but sleep. I don't actually need to sleep, but in my real environment I have to collect data before each forward pass (which costs about 1e-5 seconds), and that also makes the forward pass take 0.01 s or more, which is unacceptable.