Bizarre extra time consumption in PyTorch (GPU) 1.1.0 / 1.2.0?

Hi, everyone,

I have been struggling with an extra and confusing time-consumption problem, which I have now encountered twice:

First time:
I have two models running simultaneously with a joint loss, but each step is quite slow (up to 0.4s–0.5s). So I was trying to improve the efficiency by checking a few time spots with time.time(), and I found that the first, small model is fine (around 0.002s). However, when I checked the second model, something weird happened right at the beginning of its forward pass:

def forward(feature, coords, permutation):
    """
    feature: [B, d, N, 1]
    coords:  [B, N, 3]
    """
    t1 = time.time()
    coords = coords[:, permutation]
    t2 = time.time()
    feature = feature[:, :, permutation]
    t3 = time.time()
    print(t2 - t1)  # ~0.24s
    print(t3 - t2)  # ~0.00..., very short anyway

# the same forward pass with the two indexing lines swapped:
def forward(feature, coords, permutation):
    """
    feature: [B, d, N, 1]
    coords:  [B, N, 3]
    """
    t1 = time.time()
    feature = feature[:, :, permutation]
    t2 = time.time()
    coords = coords[:, permutation]
    t3 = time.time()
    print(t2 - t1)  # ~0.24s
    print(t3 - t2)  # ~0.00..., very short anyway

It seems that what really matters is the order of the operations, not what the code actually does.
Second time:

I was debugging my code with time.time() as before, and the time consumption got even weirder. In my second model I found a very time-consuming line, a[b], especially when b had been moved with b.to(device); it is just a simple slicing operation, but it took 0.4s+, more than 95% of the whole model's runtime. So I changed the way I take the slice, removed the .to(device) part, and it dropped to a few milliseconds, as if the issue were solved. Unfortunately, the total time remained unchanged. What is more, another line of code suddenly became very time-consuming for no apparent reason, like a ghost, with absolutely no changes made to it. So I rewrote that line as well, but in vain, because the "time-consuming" part just moved on to the next spot like a virus.

However, I found a common feature: every "ghost" time-consuming part is related to the device in some way (.to(device) / device / … etc.), and when you fix one of them, another one simply becomes the "ghost" time-consuming part instead.

Is there some GPU preparation stage that can't be removed? If so, why does every step need it? Literally every step in my iteration has a "ghost" time spot.

Yours sincerely,
ZD

CUDA operations are executed asynchronously, so you would need to synchronize the code before starting and stopping the timers via torch.cuda.synchronize().
Since you are not synchronizing at the moment, the reported times will accumulate the runtime of all previously enqueued kernels until a synchronization is forced.
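For reference, a minimal sketch of this pattern (the tensor and the indexing operation below are made-up stand-ins for the real workload); torch.cuda.Event is an alternative to wrapping time.time() in synchronizations:

import time
import torch

x = torch.randn(1, 128, 1024, 1, device='cuda')   # stand-in tensor, not your real input
perm = torch.randperm(1024, device='cuda')

# Option 1: wall-clock timing with explicit synchronizations
torch.cuda.synchronize()            # wait for all previously enqueued kernels
t0 = time.time()
y = x[:, :, perm]                   # the operation to profile
torch.cuda.synchronize()            # wait for this kernel to finish
t1 = time.time()
print('indexing took', t1 - t0, 's')

# Option 2: CUDA events record timestamps on the GPU itself
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
y = x[:, :, perm]
end.record()
torch.cuda.synchronize()            # make sure both events have been recorded
print('indexing took', start.elapsed_time(end), 'ms')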

Could you add the synchronizations and rerun your profiling, please?

Thanks for your reply,

I tried using it before time.time(), like this:

torch.cuda.synchronize()
t1 = time.time()
# ... some operations here ...
torch.cuda.synchronize()
t2 = time.time()
print("Time consumption:", t2 - t1)

I added torch.cuda.synchronize() before every time.time() and got quite a different result compared to using time.time() alone. I finally localized the time-consuming part and found that the code below actually takes 0.4s+. From what you told me, can I assume that torch.cuda.synchronize() gives me the actual runtime of these operations?
If so, then my question becomes:

# inputs shape: [1, 128, 120, 120], b = 1
torch.max(inputs.view(b, -1), dim=1)[0]  # very time consuming, ~0.4s+
torch.max(inputs)                        # very quick, ~0.0003s

Why is there such a difference? I can't see the mechanism behind it; from what I can tell, it shouldn't behave like this.

Synchronizations make sure the timers capture the runtime of all operations between the sync points.
Otherwise your profiling would run into the "ghost" operations you've mentioned, i.e. most of the time you would profile only the kernel launch and would accumulate the complete runtime of all scheduled kernels into the next blocking operation.
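As an illustration of that accumulation effect (a minimal sketch; the matmul is just a stand-in for an expensive kernel):

import time
import torch

x = torch.randn(4096, 4096, device='cuda')

torch.cuda.synchronize()
t0 = time.time()
y = torch.matmul(x, x)   # the kernel is only launched here; the call returns immediately
t1 = time.time()         # so t1 - t0 measures little more than the launch overhead

z = y.sum().item()       # .item() has to wait for the matmul to finish ...
t2 = time.time()         # ... so its apparent runtime contains the matmul as well
print('launch:', t1 - t0, 'blocking call:', t2 - t1)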

The first operation will return the max values in dim=1, so the output would contain the values and indices in the shape [b], while the second call treats the inputs tensor as a flat tensor and will only return the max. value as a scalar.
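A small sketch of the two return types (shapes only, not timing; the b = 1 below mirrors the shape from your snippet):

import torch

b = 1
inputs = torch.randn(b, 128, 120, 120, device='cuda')

values, indices = torch.max(inputs.view(b, -1), dim=1)  # returns (values, indices) along dim=1
print(values.shape, indices.shape)                      # torch.Size([1]) torch.Size([1])

scalar_max = torch.max(inputs)                          # global max as a 0-dim tensor
print(scalar_max.shape)                                 # torch.Size([])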

That being said, I cannot reproduce the large timing difference using this code:

import time
import torch

x = torch.randn(1024, 1024*128, device='cuda')

nb_iters = 1000

torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    out = torch.max(x, dim=1)
torch.cuda.synchronize()
t1 = time.time()
print((t1 - t0)/nb_iters)
> 0.012184044361114503

torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_iters):
    out = torch.max(x)
torch.cuda.synchronize()
t1 = time.time()
print((t1 - t0)/nb_iters)
> 0.012206459522247314

Thank you for your answer! It’s much clearer now.

After I captured the "ghost time" using torch.cuda.synchronize(), namely with the code I showed above, I changed torch.max(torch.max(inputs.view(b, -1), dim=1)[0]) into torch.max(inputs), and the "ghost" time vanished from my whole project after a long stretch of debugging and repeatedly putting it off. However, I still don't understand the reason behind it; maybe it is due to the computation graph or something else I didn't pay much attention to, because when I run these two operations separately in a new file outside my project, I get the same result as you produced.

Anyway, the problem is gone, even though I haven't figured out exactly why.

If you don’t synchronize the code before starting and stopping the timer, you won’t capture the real runtime of the CUDA operations, since they are executed asynchronously in the background.
Basically, your Python script calls the CUDA operation, which is scheduled on the GPU and executed, while the script simply goes on executing the next lines of code.
E.g. in this code snippet:

t0 = time.time()
y = x.sum() # x is a CUDATensor
t1 = time.time()

you would only profile the Python overhead of calling the summation CUDA kernel, not the execution itself.
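With synchronizations around the call, the timer captures the actual kernel runtime instead of just the launch overhead (a minimal sketch, reusing the x from above):

torch.cuda.synchronize()  # make sure nothing else is still running on the GPU
t0 = time.time()
y = x.sum()               # kernel is launched and executed
torch.cuda.synchronize()  # wait until the sum kernel has finished
t1 = time.time()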

The code would block automatically, e.g. if you are trying to use a result tensor:

# async execution
output = model(x)
loss = criterion(output, target)
loss.backward()
# blocking
print(loss.cpu().item())

These synchronization points might be in different parts of your code, so your profiling result would accumulate all previous CUDA operations into the first blocking operation, which will yield wrong results.
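A minimal sketch of timing a full iteration with explicit synchronization points (model, criterion, x and target are assumed to be defined and on the GPU, as in the snippet above):

torch.cuda.synchronize()          # make sure nothing from the previous iteration is still running
t0 = time.time()

output = model(x)                 # forward
loss = criterion(output, target)  # loss computation
loss.backward()                   # backward

torch.cuda.synchronize()          # wait until all kernels of this iteration have finished
t1 = time.time()
print('iteration time:', t1 - t0)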