Inconsistent model inference time

Hi all. I am a graduate student working on a project on model inference time optimization.
I am using this very simple code snippet to benchmark the inference time of ResNet-50:

    for i, (batch_input, batch_target) in enumerate(data_loader):
        batch_input_var = batch_input.to(device)
        # make sure the H2D copy (and any previous work) has finished before timing
        torch.cuda.synchronize()

        # time.sleep(0.2)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        with torch.no_grad():
            output = net(batch_input_var)
        end.record()
        # wait for the forward pass to finish before reading the elapsed GPU time
        torch.cuda.synchronize()
        print(start.elapsed_time(end))

It consistently gives me around 10.5 ms.

However, if I uncomment time.sleep(0.2), or do something else that takes around 0.05 s to 0.2 s, the results become highly inconsistent, ranging from 12.7 ms to 26.1 ms. If I set the sleep time to, say, 2 seconds, the results become consistent again at around 15 ms.
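
A small variation of the loop above that collects the per-iteration timings and prints summary statistics (just a sketch, reusing the same net, device and data_loader) makes the spread easy to see:

    import statistics

    timings_ms = []
    with torch.no_grad():
        for i, (batch_input, batch_target) in enumerate(data_loader):
            batch_input_var = batch_input.to(device)
            torch.cuda.synchronize()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            output = net(batch_input_var)
            end.record()
            torch.cuda.synchronize()
            timings_ms.append(start.elapsed_time(end))

    # summarize the spread across iterations
    print('mean {:.2f} ms, stdev {:.2f} ms, min {:.2f} ms, max {:.2f} ms'.format(
        statistics.mean(timings_ms), statistics.stdev(timings_ms),
        min(timings_ms), max(timings_ms)))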

I am wondering if you have any thoughts on this?

Thanks!

Any hope of getting help with this?

I haven’t seen this issue before and also cannot reproduce it using a small example model and random input data.
Could you share a code snippet to reproduce this issue and also share your setup information?

Hi @ptrblck Thanks a lot for your response!

I have included a minimal script at the end of this post that shows the problem.

I am using Python 3.6.9, PyTorch 1.7.1, an RTX 2080, CUDA 11.0, and driver version 450.80.02.

Script:

    import torchvision.models as models
    import torch
    import numpy as np
    import time


    if __name__ == '__main__':
        device = 0
        net = models.resnet50(pretrained=True).to(device)
        net.eval()

        # random input data: 200 samples of shape (3, 480, 854)
        data = np.random.uniform(0, 1, (200, 3, 480, 854))
        label = np.random.randint(0, 100, (200,))
        tensor_imgs = torch.Tensor(data.astype(np.float32))
        tensor_labels = torch.Tensor(label.astype(np.float32))
        dataset = torch.utils.data.dataset.TensorDataset(tensor_imgs, tensor_labels)
        data_loader = torch.utils.data.dataloader.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=10)

        # warm-up: 100 forward passes so initialization is excluded from the timing
        for i, (batch_input, batch_target) in enumerate(data_loader):
            if i >= 100:
                break
            batch_input_var = batch_input.to(device)
            output = net(batch_input_var)
        print('Warm-up finished')

        # timed loop: CUDA events measure the GPU time of each forward pass
        for i, (batch_input, batch_target) in enumerate(data_loader):
            batch_input_var = batch_input.to(device)
            torch.cuda.synchronize()

            time.sleep(0.2)
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            with torch.no_grad():
                output = net(batch_input_var)
            end.record()
            torch.cuda.synchronize()
            print(start.elapsed_time(end))

Thanks, I can reproduce the different timing results using time.sleep.
However, if I add a real workload, e.g.:

    def fun2():
        # CPU-only workload (numpy matmuls) instead of sleeping
        for _ in range(200):
            np.matmul(np.random.randn(1024, 1024), np.random.randn(1024, 1024))

    if __name__ == '__main__':
        device = 0
        ...
        for i, (batch_input, batch_target) in enumerate(data_loader):
            batch_input_var = batch_input.to(device)
            torch.cuda.synchronize()

            #time.sleep(0.2)
            fun2()
            ...

I get valid results again (although the code overall will have a large delay), so I assume time.sleep might distort the profiling somehow.
I’ll ask around, since this is new to me.
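
One quick experiment that might help narrow it down (just a sketch, not something I have verified): replace the sleep with a pure CPU busy-wait of the same length. The GPU is idle in both cases, so if the busy-wait also gives stable timings, the trigger would seem to be the process actually sleeping rather than the GPU simply having no work:

    import time

    def busy_wait(seconds):
        # spin on the CPU without touching the GPU
        deadline = time.perf_counter() + seconds
        while time.perf_counter() < deadline:
            pass

and then call busy_wait(0.2) in place of time.sleep(0.2) / fun2() in the loop.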

@ptrblck Thanks!

The motivation for asking is that I am designing a real-time scheduler for model inference.

When there’s no workload on the system (the system is idle, waiting for requests) and an inference request then arrives and starts to be processed on the GPU, I get the unstable inference times, just like in the time.sleep() situation. When the GPU is busy processing some data and then immediately switches to new data, the inference time for the new data is quite normal.

I also tried time.perf_counter in place of the torch.cuda.Event timing, and it gives similar results.
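
Roughly, the replacement looks like this inside the same loop (a sketch; the synchronize calls before and after the forward pass are there so that perf_counter only brackets the GPU work):

    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        output = net(batch_input_var)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    print((t1 - t0) * 1000.0)  # milliseconds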

I hope this provides more context for the problem.

Thank you very much!

The “wake-up” time of the GPU might be expected, and I don’t know if you can avoid it, since the GPU will reduce its performance state to use less power (as seen via nvidia-smi) when no workload is currently being executed.
What I don’t know is why time.sleep interferes with the profiling, but I’m also not an expert in Python multiprocessing and don’t know whether the main process is just suspended for the specified amount of time or whether “more” is happening under the hood.

@ptrblck Thanks for the direction!
I will check whether there’s a way to keep the GPU alive.
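
One idea I want to try (just a sketch, completely untested): launch a tiny dummy kernel from a background thread every few tens of milliseconds between requests, so the GPU never idles long enough to drop its clocks:

    import threading
    import time
    import torch

    def gpu_keep_alive(device, interval_s=0.05, stop_event=None):
        # tiny dummy workload submitted periodically to keep the GPU busy
        x = torch.ones(32, 32, device=device)
        while stop_event is None or not stop_event.is_set():
            torch.matmul(x, x)
            torch.cuda.synchronize()
            time.sleep(interval_s)

    stop = threading.Event()
    threading.Thread(target=gpu_keep_alive, args=(0, 0.05, stop), daemon=True).start()
    # ... serve inference requests as usual, then call stop.set() when done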

You can check the state of your devices via nvidia-smi -q -d PERFORMANCE; you should see one of the Clocks Throttle Reasons listed as Active (most likely Idle).
Using nvidia-settings you could try to force your device to run at P0 at all times, if that’s what you want.
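
If you want to check the state from inside the script, something like this should also work (a sketch; it assumes the pynvml bindings are installed, e.g. via pip install nvidia-ml-py), so you can print the performance state right next to each timing measurement:

    import pynvml

    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    # performance state as an integer: 0 == P0, 2 == P2, ...
    pstate = pynvml.nvmlDeviceGetPerformanceState(handle)
    sm_clock = pynvml.nvmlDeviceGetClockInfo(handle, pynvml.NVML_CLOCK_SM)  # MHz
    print('P{}, SM clock {} MHz'.format(pstate, sm_clock))
    pynvml.nvmlShutdown()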

Hi @ptrblck I tried the command nvidia-settings -a '[gpu:0]/GPUPowerMizerMode=1'. It changes the performance state to P0 in nvidia-smi, but whenever I launch the script it falls back to P2 (all throttle reasons are inactive), and the inference time still jumps around. When the script finishes it goes back to P0. Do you know if there’s a way to keep it at P0?

Hi @milesyang, have you solved this problem?
I tested the code on an RTX 3090 (desktop) and an RTX 2080 Max-Q (laptop). I found that the performance state is always P2 for the RTX 3090, and the inference time stays the same no matter what the sleep time is. But for the RTX 2080 Max-Q, the performance state changes with the sleep time: if the sleep time is less than 50 ms, the performance state is always P0 and the inference time is normal, but if the sleep time is 500 ms, the performance state jumps between P0, P3, and P5, and the inference time jumps as well. So I assume P2 is the normal power state for the RTX 3090 and P0 for the RTX 2080 Max-Q, but since the RTX 2080 Max-Q is a laptop GPU, its performance state tends to fluctuate.