How can I optimize the GPU memory transfer time when dynamically loading network weights during the forward pass in PyTorch?

I have implemented a library to dynamically load network weights during the forward pass of a network in order to reduce device memory usage. The weights are sharded and loaded independently for each module. This works well on CPU, but on GPU the host-to-device tensor transfers take a long time and greatly slow down the computation (90% of the time is spent loading tensors).
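For reference, here is a minimal sketch of the pattern, not the actual library code (the class name and shard path are made up): each module keeps only a reference to its weight shard, loads it inside forward(), copies it to the GPU for the matmul, and drops it afterwards, so every call pays a host-to-device copy.

import torch
import torch.nn as nn

class WeightlessLinear(nn.Module):
    # Hypothetical layer whose weight lives on disk until forward() runs.
    def __init__(self, shard_path: str):
        super().__init__()
        self.shard_path = shard_path                          # made-up shard file

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = torch.load(self.shard_path, map_location="cpu")   # load the shard on the host
        w = w.to(x.device)                                    # H2D copy on every call
        return x @ w.t()                                      # weight is freed after use

The profiler output for a generation run looks like this: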

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        95.41%      709.555s       100.00%      743.710s      743.710s       0.000us         0.00%       35.203s       35.203s       1.87 Gb       1.87 Gb       3.18 Gb       2.96 Gb             1  
                                        cudaMemcpyAsync         4.35%       32.383s         4.35%       32.383s       7.243ms     142.235ms         0.44%     142.235ms      31.813us           0 b           0 b           0 b           0 b          4471  
                                  cudaStreamSynchronize         0.03%     243.736ms         0.03%     243.736ms      54.637us     497.634ms         1.54%     497.634ms     111.552us           0 b           0 b           0 b           0 b          4461  
                                            aten::copy_         0.03%     208.714ms         4.40%       32.746s       2.479ms       31.600s        98.08%       32.393s       2.453ms           0 b           0 b           0 b           0 b         13207  
                                       cudaLaunchKernel         0.02%     169.343ms         0.02%     169.343ms       9.537us        1.625s         5.04%        1.625s      91.505us           0 b           0 b      -8.00 Kb      -8.00 Kb         17757  
                                            aten::empty         0.02%     140.044ms         0.02%     140.044ms      11.568us       0.000us         0.00%       0.000us       0.000us         268 b         268 b       2.17 Mb       2.17 Mb         12106  
                                             aten::set_         0.02%     138.528ms         0.02%     138.528ms      12.366us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b         11202  
                                    aten::empty_strided         0.01%     106.856ms         0.01%     106.856ms       8.462us       0.000us         0.00%       0.000us       0.000us           0 b           0 b      63.85 Mb      63.85 Mb         12628  
                                         aten::_to_copy         0.01%      76.763ms         0.03%     249.047ms      26.240us       0.000us         0.00%     292.004ms      30.766us           0 b           0 b      32.47 Mb           0 b          9491  
                                               aten::mm         0.01%      73.769ms         0.01%     109.808ms      46.022us     540.683ms         1.68%        1.047s     438.988us           0 b           0 b      47.87 Mb      47.87 Mb          2386  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 743.710s
Self CUDA time total: 32.220s

Is there a way to optimize the GPU memory transfer time in PyTorch when dynamically loading network weights?

Here is a link to the project’s GitHub and here is a link to a notebook for running inference.

To perform the profiling I used this code:

import torch
from torch.profiler import profile, record_function, ProfilerActivity

# tokenizer, DEVICE, and odewel_model are assumed to be set up beforehand,
# as in the linked notebook.
PROMPT = "translate English to German: How old are you?"
input_ids = tokenizer(PROMPT, return_tensors="pt").input_ids
input_ids = input_ids.to(DEVICE)

torch.manual_seed(0)
with profile(activities=[
        ProfilerActivity.CPU, ProfilerActivity.CUDA], profile_memory=True, record_shapes=True) as prof:
    with record_function("model_inference"):
        odewel_model.generate(input_ids, max_new_tokens=10)   

print(prof.key_averages().table(sort_by="self_cpu_time_total"))     

Copies to the GPU are the main bottleneck you'll face. One technique to address this is CUDA graphs; the easiest way to use them is torch.compile(m, mode="reduce-overhead").
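A minimal sketch of that suggestion, using a hypothetical standalone layer rather than the odewel model; note that CUDA graphs assume the weight storage stays at a fixed address, which may conflict with reloading weights into fresh tensors every forward pass:

import torch
import torch.nn as nn

layer = nn.Linear(1024, 1024).cuda()   # stand-in for one layer of the model

# mode="reduce-overhead" captures CUDA graphs, which replay the recorded
# kernel launches and remove most of the per-call CPU launch overhead
fast_layer = torch.compile(layer, mode="reduce-overhead")

x = torch.randn(8, 1024, device="cuda")
with torch.inference_mode():
    for _ in range(3):                 # the first calls warm up / record the graph
        y = fast_layer(x)
torch.cuda.synchronize()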

You can also look at instantiating tensors on the GPU directly by calling torch.set_default_device('cuda') in your code.
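A minimal sketch, assuming PyTorch 2.0 or later; this only helps for tensors that are created without an explicit device= argument, so that they land on the GPU instead of being built on the CPU and copied over:

import torch

torch.set_default_device("cuda")

w = torch.empty(1024, 1024)   # allocated directly on the GPU
print(w.device)               # cuda:0

# the default can also be scoped with a context manager
with torch.device("cpu"):
    b = torch.zeros(10)       # back on the CPU inside this block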

Finally, you might want to look at fast GPU decoding with technologies like NVIDIA DALI.
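A minimal sketch of a DALI pipeline, assuming the nvidia-dali package is installed and a hypothetical directory of JPEG files; DALI accelerates input-data decoding (images/video/audio) on the GPU rather than weight loading, so it only applies if input preprocessing is part of your pipeline:

from nvidia.dali import pipeline_def, fn

@pipeline_def(batch_size=32, num_threads=4, device_id=0)
def decode_pipeline(data_dir):
    jpegs, labels = fn.readers.file(file_root=data_dir)
    # device="mixed" decodes JPEGs on the GPU via nvJPEG
    images = fn.decoders.image(jpegs, device="mixed")
    return images, labels

pipe = decode_pipeline("/path/to/images")   # hypothetical data directory
pipe.build()
images, labels = pipe.run()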
