Tensor.item() takes a lot of running time

Hi,

I would guess your problem is the same as https://discuss.pytorch.org/t/checking-tensor-is-all-0s-tensor-sum-data-0-0-extremely-slow/16613.
Set the CUDA_LAUNCH_BLOCKING=1 environment variable before running your script in a profiler to avoid this behaviour.

1 Like

Hi,

I am facing the same issue. In my case, CUDA_LAUNCH_BLOCKING=1 reduces the time taken by .item() but increases the forward+backward pass time by an equal amount, so the total training time remains the same.

I also came across the thread Synchronization slow down caused by .item() which is not caused by .data[0], which suggests that using .data[0] instead of .item() is faster. In my case .data[0] does not take time itself, but it increases the time taken by data loading. So again, the total training time remains the same.

Any other suggestions/fixes that I can try?

EDIT: For my example, num_workers = 8 and batch_size = 32. Modifying num_workers only impacts the data loading time, as expected.

EDIT 2: The increase in data loading time when I use .data[0] instead of .item() occurs when I move images to the GPU (images.cuda()) after the dataloader returns a tensor containing 32 images.

1 Like

Hi,

The point above is that item() looks like it’s taking a lot of time because it causes a synchronization of your GPU.
But the item() call itself is not what takes time; it’s the rest of the operations that are still running on the GPU.
That’s what you see with CUDA_LAUNCH_BLOCKING=1: every operation is forced to be synchronous, so nothing is left to be done when you call item() and it executes quickly.
The behaviour of .data[0] that you see is because it delays the sync even further (to somewhere in the dataloader).

As you would expect, the total runtime is always the same: these calls don’t actually take any time themselves, they just change where and how the CUDA sync happens.
Note that if you want to profile the runtime of the CUDA ops, you should set CUDA_LAUNCH_BLOCKING=1 so that you measure each operation’s runtime, not the sync points at the end.
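
For example (just a rough sketch with a dummy matmul workload instead of your actual model), this is the kind of difference you would see:

import time
import torch

# A stand-in for a real workload: a few large matmuls on the GPU.
x = torch.randn(4096, 4096, device="cuda")

def workload(t):
    out = t
    for _ in range(10):
        out = torch.mm(t, t)      # kernels are queued asynchronously
    return out.sum()

# Without a sync this only measures the time needed to *queue* the kernels:
t0 = time.time()
loss = workload(x)
print("workload (no sync): %.4fs" % (time.time() - t0))

# .item() then looks expensive, because it has to wait for all the queued work:
t0 = time.time()
val = loss.item()
print("item():             %.4fs" % (time.time() - t0))

# With an explicit synchronize (or CUDA_LAUNCH_BLOCKING=1) the time is
# attributed to the ops that actually run on the GPU, and item() is cheap:
t0 = time.time()
loss = workload(x)
torch.cuda.synchronize()
print("workload (synced):  %.4fs" % (time.time() - t0))

t0 = time.time()
val = loss.item()
print("item() after sync:  %.4fs" % (time.time() - t0))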

3 Likes

Thanks for the explanation. So is there any alternative I can try to reduce the delay due to the CUDA sync?

In your case, it seems there is nothing to gain: as you said, the total runtime with CUDA_LAUNCH_BLOCKING=1 is the same as without, so the async execution does not buy you a significant amount of time. The place where the sync happens (everywhere, at the .item(), or in the dataloader) will therefore not impact the total runtime.

1 Like

I did some profiling using the PyTorch profiler (I’m not an expert), and it seems that indexing a tensor is faster than using .item(): the latter requires CUDA + CPU operations while the former only needs the CPU.

profiling code: profiler.py

import torch
from torch.profiler import profile, record_function, ProfilerActivity


def function():
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    z = torch.rand(1000, 1000).to(DEVICE)
    zs = z.sum().view(1, )
    torch.cuda.synchronize()

    with profile(activities=[ProfilerActivity.CPU,
                             ProfilerActivity.CUDA],
                 record_shapes=True) as prof:
        with record_function("compute_th"):
            zs.item()

    print(prof.key_averages().table(sort_by="cuda_time_total",
                                    row_limit=10))
    with profile(activities=[ProfilerActivity.CPU,
                             ProfilerActivity.CUDA],
                 record_shapes=True) as prof:
        with record_function("compute_th"):
            zs[0]

    print(prof.key_averages().table(sort_by="cuda_time_total",
                                    row_limit=10))


if __name__ == '__main__':
    cuda = "1"
    DEVICE = torch.device(
        "cuda:{}".format(cuda) if torch.cuda.is_available() else "cpu")
    function()
$ python profiler.py 
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                          compute_th        87.05%       1.708ms        94.34%       1.851ms       1.851ms       0.000us         0.00%       2.000us       2.000us             1  
                          aten::item         1.17%      23.000us         7.14%     140.000us     140.000us       0.000us         0.00%       2.000us       2.000us             1  
           aten::_local_scalar_dense         2.96%      58.000us         5.96%     117.000us     117.000us       2.000us       100.00%       2.000us       2.000us             1  
    Memcpy DtoH (Device -> Pageable)         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us       100.00%       2.000us       2.000us             1  
                         aten::zeros         2.91%      57.000us         5.15%     101.000us     101.000us       0.000us         0.00%       0.000us       0.000us             1  
                         aten::empty         0.92%      18.000us         0.92%      18.000us       9.000us       0.000us         0.00%       0.000us       0.000us             2  
                         aten::zero_         1.48%      29.000us         1.48%      29.000us      29.000us       0.000us         0.00%       0.000us       0.000us             1  
                     cudaMemcpyAsync         2.70%      53.000us         2.70%      53.000us      53.000us       0.000us         0.00%       0.000us       0.000us             1  
               cudaStreamSynchronize         0.31%       6.000us         0.31%       6.000us       6.000us       0.000us         0.00%       0.000us       0.000us             1  
               cudaDeviceSynchronize         0.51%      10.000us         0.51%      10.000us      10.000us       0.000us         0.00%       0.000us       0.000us             1  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.962ms
Self CUDA time total: 2.000us

-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
              aten::zeros         5.13%      12.000us         8.12%      19.000us      19.000us             1  
              aten::empty         3.42%       8.000us         3.42%       8.000us       4.000us             2  
              aten::zero_         0.43%       1.000us         0.43%       1.000us       1.000us             1  
               compute_th        69.23%     162.000us        88.46%     207.000us     207.000us             1  
             aten::select        15.81%      37.000us        18.38%      43.000us      43.000us             1  
         aten::as_strided         2.56%       6.000us         2.56%       6.000us       6.000us             1  
    cudaDeviceSynchronize         3.42%       8.000us         3.42%       8.000us       8.000us             1  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 234.000us
$ CUDA_LAUNCH_BLOCKING=1 python profiler.py 
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                          compute_th        82.48%       1.224ms        91.71%       1.361ms       1.361ms       0.000us         0.00%       2.000us       2.000us             1  
                          aten::item         2.36%      35.000us         9.03%     134.000us     134.000us       0.000us         0.00%       2.000us       2.000us             1  
           aten::_local_scalar_dense         2.63%      39.000us         6.67%      99.000us      99.000us       2.000us       100.00%       2.000us       2.000us             1  
    Memcpy DtoH (Device -> Pageable)         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us       100.00%       2.000us       2.000us             1  
                         aten::zeros         4.58%      68.000us         7.68%     114.000us     114.000us       0.000us         0.00%       0.000us       0.000us             1  
                         aten::empty         0.94%      14.000us         0.94%      14.000us       7.000us       0.000us         0.00%       0.000us       0.000us             2  
                         aten::zero_         2.36%      35.000us         2.36%      35.000us      35.000us       0.000us         0.00%       0.000us       0.000us             1  
                     cudaMemcpyAsync         3.71%      55.000us         3.71%      55.000us      55.000us       0.000us         0.00%       0.000us       0.000us             1  
               cudaStreamSynchronize         0.34%       5.000us         0.34%       5.000us       5.000us       0.000us         0.00%       0.000us       0.000us             1  
               cudaDeviceSynchronize         0.61%       9.000us         0.61%       9.000us       9.000us       0.000us         0.00%       0.000us       0.000us             1  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.484ms
Self CUDA time total: 2.000us

-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
              aten::zeros         4.69%      12.000us         7.42%      19.000us      19.000us             1  
              aten::empty         3.12%       8.000us         3.12%       8.000us       4.000us             2  
              aten::zero_         0.39%       1.000us         0.39%       1.000us       1.000us             1  
               compute_th        69.14%     177.000us        89.06%     228.000us     228.000us             1  
             aten::select        17.58%      45.000us        19.14%      49.000us      49.000us             1  
         aten::as_strided         1.56%       4.000us         1.56%       4.000us       4.000us             1  
    cudaDeviceSynchronize         3.52%       9.000us         3.52%       9.000us       9.000us             1  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 256.000us

Now, when we switch to other operations that implicitly use .item() (according to this), such as max(), we get this:

import torch
from torch.profiler import profile, record_function, ProfilerActivity


def function():
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    z = torch.rand(1000, 1000).to(DEVICE)
    zs = z.sum().view(1, )
    torch.cuda.synchronize()

    with profile(activities=[ProfilerActivity.CPU,
                             ProfilerActivity.CUDA],
                 record_shapes=True) as prof:
        with record_function("compute_th"):
            z.max()

    print(prof.key_averages().table(sort_by="cuda_time_total",
                                    row_limit=10))
    with profile(activities=[ProfilerActivity.CPU,
                             ProfilerActivity.CUDA],
                 record_shapes=True) as prof:
        with record_function("compute_th"):
            z.max().view(1, )[0]

    print(prof.key_averages().table(sort_by="cuda_time_total",
                                    row_limit=10))


if __name__ == '__main__':
    cuda = "1"
    DEVICE = torch.device(
        "cuda:{}".format(cuda) if torch.cuda.is_available() else "cpu")
    function()

results:

$ python profiler.py 
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             compute_th         0.01%     265.000us        99.99%        1.910s        1.910s       0.000us         0.00%      25.000us      25.000us             1  
                                              aten::max         0.09%       1.748ms        99.98%        1.909s        1.909s      25.000us       100.00%      25.000us      25.000us             1  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      24.000us        96.00%      24.000us      24.000us             1  
                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         4.00%       1.000us       1.000us             1  
                                            aten::zeros         0.00%      48.000us         0.00%      85.000us      85.000us       0.000us         0.00%       0.000us       0.000us             1  
                                            aten::empty         0.00%      45.000us         0.00%      45.000us      15.000us       0.000us         0.00%       0.000us       0.000us             3  
                                            aten::zero_         0.00%      23.000us         0.00%      23.000us      23.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       aten::as_strided         0.00%       5.000us         0.00%       5.000us       5.000us       0.000us         0.00%       0.000us       0.000us             1  
                                        cudaMemsetAsync         0.00%      31.000us         0.00%      31.000us      31.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       cudaLaunchKernel        99.88%        1.908s        99.88%        1.908s        1.908s       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.910s
Self CUDA time total: 25.000us

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             compute_th         2.10%      73.000us        98.39%       3.422ms       3.422ms       0.000us         0.00%      22.000us      22.000us             1  
                                              aten::max        18.17%     632.000us        94.59%       3.290ms       3.290ms      22.000us       100.00%      22.000us      22.000us             1  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      21.000us        95.45%      21.000us      21.000us             1  
                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         4.55%       1.000us       1.000us             1  
                                            aten::zeros         0.95%      33.000us         1.38%      48.000us      48.000us       0.000us         0.00%       0.000us       0.000us             1  
                                            aten::empty         0.83%      29.000us         0.83%      29.000us       9.667us       0.000us         0.00%       0.000us       0.000us             3  
                                            aten::zero_         0.09%       3.000us         0.09%       3.000us       3.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       aten::as_strided         0.14%       5.000us         0.14%       5.000us       2.500us       0.000us         0.00%       0.000us       0.000us             2  
                                        cudaMemsetAsync         0.86%      30.000us         0.86%      30.000us      30.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       cudaLaunchKernel        75.04%       2.610ms        75.04%       2.610ms       2.610ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 3.478ms
Self CUDA time total: 22.000us
$ CUDA_LAUNCH_BLOCKING=1 python profiler.py 
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             compute_th         0.02%     247.000us        99.99%        1.585s        1.585s       0.000us         0.00%      27.000us      27.000us             1  
                                              aten::max         0.08%       1.316ms        99.98%        1.585s        1.585s      27.000us       100.00%      27.000us      27.000us             1  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      25.000us        92.59%      25.000us      25.000us             1  
                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us         7.41%       2.000us       2.000us             1  
                                            aten::zeros         0.00%      51.000us         0.01%      90.000us      90.000us       0.000us         0.00%       0.000us       0.000us             1  
                                            aten::empty         0.00%      30.000us         0.00%      30.000us      10.000us       0.000us         0.00%       0.000us       0.000us             3  
                                            aten::zero_         0.00%      28.000us         0.00%      28.000us      28.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       aten::as_strided         0.00%       4.000us         0.00%       4.000us       4.000us       0.000us         0.00%       0.000us       0.000us             1  
                                        cudaMemsetAsync         0.00%      27.000us         0.00%      27.000us      27.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       cudaLaunchKernel        99.89%        1.584s        99.89%        1.584s        1.584s       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.585s
Self CUDA time total: 27.000us

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             compute_th         3.66%      78.000us        98.08%       2.091ms       2.091ms       0.000us         0.00%      23.000us      23.000us             1  
                                              aten::max        16.84%     359.000us        91.98%       1.961ms       1.961ms      23.000us       100.00%      23.000us      23.000us             1  
void at::native::reduce_kernel<512, 1, at::native::R...         0.00%       0.000us         0.00%       0.000us       0.000us      22.000us        95.65%      22.000us      22.000us             1  
                                        Memset (Device)         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us         4.35%       1.000us       1.000us             1  
                                            aten::zeros         0.89%      19.000us         1.41%      30.000us      30.000us       0.000us         0.00%       0.000us       0.000us             1  
                                            aten::empty         1.22%      26.000us         1.22%      26.000us       8.667us       0.000us         0.00%       0.000us       0.000us             3  
                                            aten::zero_         0.09%       2.000us         0.09%       2.000us       2.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       aten::as_strided         0.28%       6.000us         0.28%       6.000us       3.000us       0.000us         0.00%       0.000us       0.000us             2  
                                        cudaMemsetAsync         1.31%      28.000us         1.31%      28.000us      28.000us       0.000us         0.00%       0.000us       0.000us             1  
                                       cudaLaunchKernel        72.94%       1.555ms        72.94%       1.555ms       1.555ms       0.000us         0.00%       0.000us       0.000us             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 2.132ms
Self CUDA time total: 23.000us

Note that in your benchmark, when you do zs[0] you create a new CUDA tensor containing that value, but you do not actually send it to the CPU. So it is expected for this to be faster than item().
You can call .cpu() on it, or try printing it (which does the .cpu() implicitly), to make sure you compare the same thing.
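
For illustration (a quick sketch using the same shapes as your snippet):

import torch

z = torch.rand(1000, 1000, device="cuda")
zs = z.sum().view(1,)

a = zs.item()      # syncs: the value is copied to the CPU as a Python number
b = zs[0]          # no sync: this just creates another CUDA tensor (still on the GPU)
c = zs[0].cpu()    # syncs: now the value really has been moved to the CPU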

{method ‘cpu’ of ‘torch._C._TensorBase’ objects}: this method takes the most time in my program. How can I fix this? @albanD

The cpu() operation will push the device tensor to the CPU, will synchronize the code, and will thus accumulate the times of already running kernels, so it might be a red herring.

How do I fix this error?

It’s not an error, but it could be misleading profiling. If you want to time each operation, you would either have to synchronize the code manually, use torch.utils.benchmark, or look at the timeline in a profiler.
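
For example, something along these lines (a rough sketch with a dummy matmul as the workload):

import time
import torch
from torch.utils.benchmark import Timer

x = torch.randn(4096, 4096, device="cuda")

# Option 1: synchronize manually around the op you want to time
torch.cuda.synchronize()                 # make sure nothing is pending before we start
t0 = time.time()
y = torch.mm(x, x)
torch.cuda.synchronize()                 # wait for the matmul to finish
print("mm (manual sync): %.4fs" % (time.time() - t0))

t0 = time.time()
y_cpu = y.cpu()                          # now cpu() is timed on its own
print("cpu():            %.4fs" % (time.time() - t0))

# Option 2: let torch.utils.benchmark handle warmup and synchronization
t = Timer(stmt="torch.mm(x, x)", globals={"x": x, "torch": torch})
print(t.blocked_autorange())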


I used torch.utils.bottleneck; same result.


Same result: the cpu() method takes the most time.

give any solution bro

Hi,

I am not sure this is an appropriate way to ask for help!

For your question, you can actually check the warning in the doc of the function you’re using: torch.utils.bottleneck — PyTorch 1.9.1 documentation
It does say that because CUDA is async, if you do not do any syncing, the ops that synchronize (like cpu()) will have an abnormally high impact.

The advice above was to use the benchmark tool, which does have a mode to properly time CUDA code.

I don’t know how to use the benchmark tool for my code, please give an idea.

We’ve already provided information about the needed synchronizations in your cross-posts and linked to resources which give you examples of the profiling utilities.
It seems you are unwilling to check these resources and assume that others will “solve” the issues for you.

I agree with @albanD that your way of asking for help is inappropriate, so I would like to ask you to reconsider your posting behavior on this board.

1 Like

Thank you for your approach.

You are right. I want to use item() when printing. How can I reduce this time?

As previously described, item() will synchronize the code and wait for the GPU to finish its computations, since you are explicitly transferring the tensor to the CPU and creating a Python literal. Because its value must be known before the next operation can execute, the code has to synchronize.
If you want to avoid the synchronizations, use item() less often or store the detached CUDA tensor and print it later (once you are fine with a sync point).
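
For example (a rough sketch with a dummy linear model and random data, just to show the pattern):

import torch

model = torch.nn.Linear(100, 10).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for step in range(1000):
    x = torch.randn(32, 100, device="cuda")
    y = torch.randint(0, 10, (32,), device="cuda")

    opt.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    opt.step()

    # store the detached CUDA tensor instead of calling .item() every step
    losses.append(loss.detach())

    # sync and print only occasionally, when you are fine paying for a sync point
    if step % 100 == 0:
        print(step, torch.stack(losses).mean().item())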