Do operations between tensors and scalars move the tensor to CPU?

Hello,
If I multiply a tensor that is on the GPU by a Python float, does this operation move the tensor to the CPU?

E.g.

scalar = 0.5
t = torch.ones(10, device='cuda')
res = t * scalar

I see that the result res is on the GPU and that t didn’t change device. However, I’m wondering whether the operation internally moves t to the CPU, performs the computation there, and then moves the result back to the GPU.

Thanks for any clarifications!

No. That would result in terrible performance, since you would move potentially large data to the host, execute the operation on the slower CPU, and move the result back. Instead, the scalar is passed to the CUDA kernel as an argument.

Perfect, thank you @ptrblck!

So I assume it would be better to use

scalar = torch.tensor(0.5, device='cuda')

instead of a plain float. Do you think it would make any difference?

No, it would be worse: you would trigger an explicit host-to-device memcpy instead of allowing PyTorch to use a specialized TensorIterator kernel that lifts the scalar to a kernel parameter.
This code snippet shows the effect:

import torch
from torch.profiler import profile, record_function, ProfilerActivity

x = torch.randn(10, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("scalar"):
        y = x * 0.5
print(prof.key_averages().table(sort_by="cuda_time_total"))
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                  scalar        26.74%     231.000us        99.42%     859.000us     859.000us       0.000us         0.00%       1.000us       1.000us             1  
#                                               aten::mul        42.01%     363.000us        72.69%     628.000us     628.000us       1.000us       100.00%       1.000us       1.000us             1  
# void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us       1.000us       100.00%       1.000us       1.000us             1  
#                                        cudaLaunchKernel        30.67%     265.000us        30.67%     265.000us     265.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                   cudaDeviceSynchronize         0.58%       5.000us         0.58%       5.000us       5.000us       0.000us         0.00%       0.000us       0.000us             1  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------


with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    with record_function("tensor"):
        y = x * torch.tensor(0.5, device="cuda")
print(prof.key_averages().table(sort_by="cuda_time_total"))
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                    Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
#                                                  tensor         4.82%      82.000us        99.82%       1.700ms       1.700ms       0.000us         0.00%       2.000us       2.000us             1  
#                                               aten::mul         0.88%      15.000us         1.47%      25.000us      25.000us       2.000us       100.00%       2.000us       2.000us             1  
# void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       2.000us       100.00%       2.000us       2.000us             1  
#                                             aten::empty         0.18%       3.000us         0.18%       3.000us       3.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                                aten::to        76.34%       1.300ms        93.19%       1.587ms       1.587ms       0.000us         0.00%       0.000us       0.000us             1  
#                                          aten::_to_copy         0.65%      11.000us        16.85%     287.000us     287.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                     aten::empty_strided         0.53%       9.000us         0.53%       9.000us       9.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                             aten::copy_         0.70%      12.000us        15.68%     267.000us     267.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                         cudaMemcpyAsync        14.86%     253.000us        14.86%     253.000us     253.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                   cudaStreamSynchronize         0.12%       2.000us         0.12%       2.000us       2.000us       0.000us         0.00%       0.000us       0.000us             1  
#                        Memcpy HtoD (Pageable -> Device)         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                        aten::lift_fresh         0.00%       0.000us         0.00%       0.000us       0.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                           aten::detach_         0.12%       2.000us         0.18%       3.000us       3.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                                 detach_         0.06%       1.000us         0.06%       1.000us       1.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                        cudaLaunchKernel         0.59%      10.000us         0.59%      10.000us      10.000us       0.000us         0.00%       0.000us       0.000us             1  
#                                   cudaDeviceSynchronize         0.18%       3.000us         0.18%       3.000us       3.000us       0.000us         0.00%       0.000us       0.000us             1  
# -------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------ 

The specialization is described in this comment in the PyTorch source:

// The gpu_kernel_with_scalars generates specializations that support a
// single scalar CPU argument, such as from cuda_tensor + 5. The CPU scalar
// is lifted to a kernel parameter instead of copying to device memory.
// This should be used in conjunction with TensorIterator::allow_cpu_scalars_,
// which is the default for TensorIterator::binary_op. Otherwise, all inputs
// and the output must be on the GPU.
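
The allow_cpu_scalars_ rule quoted above means a CPU scalar can meet a CUDA tensor without a device-mismatch error. As I understand it (an assumption, not spelled out in the comment), a 0-dim CPU tensor also counts as such a scalar, while a one-element 1-D CPU tensor does not. The sketch below checks this; mixes_with is a hypothetical helper, and on a machine without CUDA both calls trivially succeed.

```python
import torch


def mixes_with(x: torch.Tensor, other: torch.Tensor) -> bool:
    """Hypothetical helper: True if x * other runs without a device-mismatch error."""
    try:
        x * other
        return True
    except RuntimeError:
        return False


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.ones(10, device=device)

cpu_scalar = torch.tensor(0.5)    # 0-dim CPU tensor: eligible as a "CPU scalar"
cpu_vector = torch.tensor([0.5])  # 1-D CPU tensor with one element: not a scalar

print("0-dim CPU tensor:", mixes_with(x, cpu_scalar))
print("1-D CPU tensor:  ", mixes_with(x, cpu_vector))
```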

Thanks for the clarifications @ptrblck!

I suspected the profiling code might not answer the question correctly, because the tensors in the test are tiny and a new GPU scalar tensor is created on every call. So I rewrote the test to be more realistic, as follows, but the result is the same!

import torch
from torch.profiler import profile, record_function, ProfilerActivity

x = torch.empty((100, 100), dtype=torch.float32, device="cuda")
val_s = torch.tensor(0.5, device="cuda")
s = 0.5

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(
        wait=2,
        warmup=2,
        active=21,
    ),
) as prof:
    with record_function("scalar"):
        for _ in range(50 * 25):
            x.mul_(s)
            prof.step()
print(prof.key_averages().table(sort_by="cuda_time_total"))


with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(
        wait=2,
        warmup=2,
        active=21,
    ),
) as prof:
    with record_function("tensor"):
        for _ in range(50 * 25):
            x.mul_(val_s)
            prof.step()
print(prof.key_averages().table(sort_by="cuda_time_total"))

This printed (first for the Python float, then for the CUDA scalar tensor):

-----------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-----------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
    ProfilerStep*        71.24%     441.000us       100.00%     619.000us      29.476us     453.000us        60.48%     749.000us      35.667us            21
       aten::mul_        28.76%     178.000us        28.76%     178.000us       8.476us     296.000us        39.52%     296.000us      14.095us            21
-----------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 619.000us
Self CUDA time total: 749.000us

-----------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
             Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-----------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
    ProfilerStep*        68.62%     468.000us       100.00%     682.000us      32.476us     630.000us        55.65%       1.132ms      53.905us            21
       aten::mul_        31.38%     214.000us        31.38%     214.000us      10.190us     502.000us        44.35%     502.000us      23.905us            21
-----------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 682.000us
Self CUDA time total: 1.132ms

I did notice that for small tensors this relationship doesn’t hold, but I think that’s just measuring overhead.
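One way to check the size dependence without the profiler’s own overhead is torch.utils.benchmark.Timer, which handles warmup and synchronizes CUDA when relevant. This is a minimal sketch under the assumption that the out-of-place x * s is representative of the in-place mul_ used above; compare_scalar_kinds and the sizes in the loop are illustrative choices, not values from this thread. The 0-dim tensor is created once, outside the timed statement, so only the multiplication itself is measured.

```python
import torch
from torch.utils import benchmark


def compare_scalar_kinds(n: int, device: str, min_run_time: float = 0.2) -> dict:
    """Median seconds per call of x * <float> vs. x * <pre-allocated 0-dim tensor>."""
    x = torch.randn(n, device=device)
    env = {
        "x": x,
        "s": 0.5,                                 # plain Python float
        "s_t": torch.tensor(0.5, device=device),  # 0-dim tensor on x's device, created once
    }
    medians = {}
    for label, stmt in (("float", "x * s"), ("tensor", "x * s_t")):
        timer = benchmark.Timer(stmt=stmt, globals=env, label=label)
        medians[label] = timer.blocked_autorange(min_run_time=min_run_time).median
    return medians


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    for n in (10, 10_000, 10_000_000):
        m = compare_scalar_kinds(n, device)
        print(f"n={n:>10,}  float: {m['float'] * 1e6:8.2f} us  tensor: {m['tensor'] * 1e6:8.2f} us")
```

For tiny tensors both variants are likely dominated by fixed per-call costs, so a difference, if any, should show up mainly at the larger sizes.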