Unexpected execution time difference for identical operations on GPU

Hello everyone,
I am training a transformer with a custom FeedForward layer, and it is causing unexpected efficiency issues. The layer involves two heavy batched matrix multiplications, implemented with torch.bmm.

My code is basically:
def my_layer(x):
    x = torch.bmm(x, tensor_1)  # call it operation1
    x = torch.relu_(x)
    x = torch.bmm(x, tensor_2)  # call it operation2
    return x

In my case, the shape of x is always torch.Size([1024, 256, 512]), and the shape of tensor_1 and tensor_2 is torch.Size([1024, 512, 512]).
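For reference, the setup can be reproduced with something like the following standalone sketch (random data on CUDA, not my actual training code):

import torch

x = torch.randn(1024, 256, 512, device="cuda")
tensor_1 = torch.randn(1024, 512, 512, device="cuda")
tensor_2 = torch.randn(1024, 512, 512, device="cuda")

out = my_layer(x)  # output shape stays [1024, 256, 512]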

As seen above, operation1 and operation2 perform exactly the same kind of operation on tensors of identical shape and dtype (the shape of x stays the same throughout this code), and yet the first one takes significantly more time than the second. The timings look as follows:
<picture 1 in this post>

The difference is more pronounced with mixed precision enabled:
<picture 2 in the reply - the site only allows 1 image per post>

I then looked at what would happen if I followed the two operations with repeated uses of each, i.e. the sequence operation1, operation2, operation1, operation1, operation2, operation2
(relu omitted for brevity)
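In code, my reading of that sequence is roughly the following (relu and the timing code omitted; this is only a sketch):

y = torch.bmm(x, tensor_1)  # operation1
y = torch.bmm(y, tensor_2)  # operation2
y = torch.bmm(y, tensor_1)  # operation1 again
y = torch.bmm(y, tensor_1)  # operation1 again
y = torch.bmm(y, tensor_2)  # operation2 again
y = torch.bmm(y, tensor_2)  # operation2 again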

The results are as follows:
<picture 3 in the reply>
We can see that only the first execution of the operation takes more time.
With mixed precision enabled, even more interesting things happen:
<picture 4 in the reply>
As we can see, for a given operation the first execution is the longest, and the subsequent ones all take the same amount of time.

I have two questions:

  1. Why are execution times different for the first use of operation1 and operation2?
  2. Why, with mixed precision enabled, do repeated uses of an operation with the same parameters take less time?

My goal is to shorten the time for both operation1 and operation2: to close the gap between them and, if possible, get them down to the times shown for operation_1_again and operation_2_again.

The way I measure time is standard: I use CUDA events for the start and end of the measurement, followed by torch.cuda.synchronize().
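Concretely, the measurement looks roughly like this (a minimal sketch of the timing approach, not my exact code):

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
y = torch.bmm(x, tensor_1)  # the operation being timed
end.record()

torch.cuda.synchronize()  # wait for the kernels to finish before reading the time
print(start.elapsed_time(end))  # elapsed time in milliseconds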

Any insights into the inner workings of CUDA (I am using an A100 80GB and PyTorch 2.0) and into mixed-precision intricacies would be greatly appreciated.


pic 2


pic 3


pic 4


Could you check which kernels are called in both use cases and if they are the same? Also, did you allow TF32 to be used in matmuls?

As to allowing TF32, I did not take any explicit steps to do that.
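My understanding is that explicitly enabling it would look something like the snippet below (I have not set these flags anywhere in my code):

import torch

# Flags that, as far as I understand, control TF32 use in matmuls and cuDNN;
# shown only for reference, not actually set in my code.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True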
As for checking which kernels are used, I'm happy to check, but what would be the preferred way of doing that?

You can use the native profiler (just follow our tutorial) or e.g. Nsight Systems.

OK, this is the context manager I used:

from torch.profiler import profile, record_function, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    use_cuda=True,
) as prof:
    with record_function("cont_moe"):
        out = my_layer(x)  # the layer shown above

print(prof.key_averages().table())  # roughly how the tables below were printed

I used the original setup with only 1 use of each operation (if the others are needed, I’m happy to run the tests again).

  1. Mixed precision enabled:
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               cont_moe         1.10%      76.000us        14.69%       1.016ms       1.016ms       0.000us         0.00%       6.143ms       6.143ms           0 b           0 b       1.50 Gb           0 b             1  
                                              aten::bmm         0.98%      68.000us        14.53%       1.005ms     251.250us       1.388ms        22.59%       9.050ms       2.263ms           0 b           0 b       2.00 Gb     256.00 Mb             4  
                                               aten::to         0.04%       3.000us        11.45%     792.000us     396.000us       0.000us         0.00%       2.600ms       1.300ms           0 b           0 b       1.00 Gb           0 b             2  
                                         aten::_to_copy         0.33%      23.000us        11.41%     789.000us     394.500us       0.000us         0.00%       2.600ms       1.300ms           0 b           0 b       1.00 Gb           0 b             2  
                                    aten::empty_strided         5.68%     393.000us         5.68%     393.000us     196.500us       0.000us         0.00%       0.000us       0.000us           0 b           0 b       1.00 Gb       1.00 Gb             2  
                                            aten::copy_         0.36%      25.000us         5.63%     389.000us     129.667us       4.437ms        72.23%       4.437ms       1.479ms           0 b           0 b           0 b           0 b             3  
                                       cudaLaunchKernel         5.52%     382.000us         5.52%     382.000us      63.667us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             6  
void at::native::unrolled_elementwise_kernel<at::nat...         0.00%       0.000us         0.00%       0.000us       0.000us       2.600ms        42.32%       2.600ms       1.300ms           0 b           0 b           0 b           0 b             2  
                                            aten::clone         0.07%       5.000us         0.43%      30.000us      30.000us       0.000us         0.00%       1.837ms       1.837ms           0 b           0 b     256.00 Mb           0 b             1  
                                       aten::empty_like         0.07%       5.000us         0.13%       9.000us       9.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b     256.00 Mb           0 b             1  
                                            aten::empty         0.06%       4.000us         0.06%       4.000us       4.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b     256.00 Mb     256.00 Mb             1  
cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFla...         0.04%       3.000us         0.04%       3.000us       0.750us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             4  
                                            aten::relu_         0.26%      18.000us         0.52%      36.000us      36.000us       0.000us         0.00%     318.000us     318.000us           0 b           0 b           0 b           0 b             1  
                                       aten::clamp_min_         0.16%      11.000us         0.26%      18.000us      18.000us     318.000us         5.18%     318.000us     318.000us           0 b           0 b           0 b           0 b             1  
                                  cudaDeviceSynchronize        85.31%       5.899ms        85.31%       5.899ms       5.899ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
void at::native::elementwise_kernel<128, 4, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       1.837ms        29.90%       1.837ms       1.837ms           0 b           0 b           0 b           0 b             1  
ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_...         0.00%       0.000us         0.00%       0.000us       0.000us       1.388ms        22.59%       1.388ms     694.000us           0 b           0 b           0 b           0 b             2  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     318.000us         5.18%     318.000us     318.000us           0 b           0 b           0 b           0 b             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.915ms
Self CUDA time total: 6.143ms
  2. Mixed precision disabled:

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               cont_moe         0.28%      58.000us         4.28%     884.000us     884.000us       0.000us         0.00%      17.387ms      17.387ms           0 b           0 b       1.00 Gb           0 b             1  
                                              aten::bmm         1.85%     381.000us         3.85%     794.000us     397.000us      14.583ms        83.87%      16.763ms       8.382ms           0 b           0 b       1.00 Gb     512.00 Mb             2  
                                            aten::clone         0.02%       5.000us         1.93%     398.000us     398.000us       0.000us         0.00%       2.180ms       2.180ms           0 b           0 b     512.00 Mb           0 b             1  
                                       aten::empty_like         0.03%       7.000us         0.05%      11.000us      11.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b     512.00 Mb           0 b             1  
                                            aten::empty         0.02%       4.000us         0.02%       4.000us       4.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b     512.00 Mb     512.00 Mb             1  
                                            aten::copy_         0.06%      13.000us         1.85%     382.000us     382.000us       2.180ms        12.54%       2.180ms       2.180ms           0 b           0 b           0 b           0 b             1  
                                       cudaLaunchKernel         1.88%     388.000us         1.88%     388.000us      97.000us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             4  
cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFla...         0.01%       3.000us         0.01%       3.000us       0.250us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            12  
                                            aten::relu_         0.08%      16.000us         0.16%      32.000us      32.000us       0.000us         0.00%     624.000us     624.000us           0 b           0 b           0 b           0 b             1  
                                       aten::clamp_min_         0.04%       9.000us         0.08%      16.000us      16.000us     624.000us         3.59%     624.000us     624.000us           0 b           0 b           0 b           0 b             1  
                                  cudaDeviceSynchronize        95.72%      19.748ms        95.72%      19.748ms      19.748ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
void at::native::elementwise_kernel<128, 2, at::nati...         0.00%       0.000us         0.00%       0.000us       0.000us       2.180ms        12.54%       2.180ms       2.180ms           0 b           0 b           0 b           0 b             1  
                                ampere_sgemm_128x128_nn         0.00%       0.000us         0.00%       0.000us       0.000us      14.583ms        83.87%      14.583ms       7.292ms           0 b           0 b           0 b           0 b             2  
void at::native::vectorized_elementwise_kernel<4, at...         0.00%       0.000us         0.00%       0.000us       0.000us     624.000us         3.59%     624.000us     624.000us           0 b           0 b           0 b           0 b             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 20.632ms
Self CUDA time total: 17.387ms

@ptrblck Does this look as expected? Thanks!