Multi-task learning: Bottleneck, multi-GPU

I wanted to explore the following model architecture:

A Llama-2 model with two task heads: one for classification and one for causal language modeling. My goal is to train a single model that takes a text as input, scores it with the classification head, and generates a response with the causal language modeling head.

In practice, I implemented this by first creating two separate Llama-2 models (cls_model: LlamaForSequenceClassification and clm_model: LlamaForCausalLM) and then replacing the classifier’s base model with the causal LM’s base model so that the two share one backbone:

del cls_model.model
cls_model.model = clm_model.model
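To make the sharing concrete, here is the same pattern reduced to toy modules (names and sizes are made up, not the real Llama-2 classes): after the reassignment both heads hold the same backbone object, so one summed loss sends gradients from both tasks into it.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the Llama-2 backbone and the two task models
# (names and sizes are illustrative, not the real transformers classes).
class ToyTaskModel(nn.Module):
    def __init__(self, backbone, head):
        super().__init__()
        self.model = backbone   # mirrors cls_model.model / clm_model.model
        self.head = head

    def forward(self, x):
        return self.head(self.model(x))

cls_model = ToyTaskModel(nn.Linear(4, 4), nn.Linear(4, 2))  # "classification head"
clm_model = ToyTaskModel(nn.Linear(4, 4), nn.Linear(4, 8))  # "LM head"

# Tie the backbones the same way as above: drop one and alias the other.
del cls_model.model
cls_model.model = clm_model.model

# Both task models now reference the *same* parameter tensors ...
assert cls_model.model.weight is clm_model.model.weight

# ... so a summed loss sends gradients from both tasks into one backbone.
x = torch.randn(3, 4)
loss = cls_model(x).sum() + clm_model(x).sum()
loss.backward()
assert clm_model.model.weight.grad is not None
```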

A single forward pass looks roughly like this (simplified for brevity, not my literal code):

loss = 0.0  # start from a scalar; torch.tensor([0]) would be an int tensor on the CPU
for task, task_data in datapoint.items():   # task is "cls" or "clm"
    task_model = models[task]               # cls_model or clm_model from above
    loss = loss + task_model(**task_data).loss
loss.backward()

While this naive implementation seems to work, it is too slow in practice.

Problem 1

I expected a larger speedup from larger batch sizes: a batch of size 1 takes roughly 1.3 s, while a batch of 10 takes about 10 s. Since the GPU processes the samples of a batch in parallel, I expected both batch sizes to take about the same time per step; am I wrong in that assumption? If not, what could explain this behavior?
I tried to profile my code using torch.utils.bottleneck but I’m not sure how to read it.
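A side note on how I time iterations: since CUDA kernels launch asynchronously, I synchronize the device around the measured region, roughly like this sketch (the `timed` helper is made up for illustration):

```python
import time
import torch

def timed(fn):
    """Run fn() and return (result, elapsed seconds). Synchronizes the GPU
    before and after, because kernel launches are asynchronous and a naive
    timer otherwise measures mostly launch overhead."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.perf_counter() - start

# usage: time one matmul (stands in for one forward/backward step)
out, seconds = timed(lambda: torch.randn(256, 256) @ torch.randn(256, 256))
```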

--------------------------------------------------------------------------------
  Environment Summary
--------------------------------------------------------------------------------
PyTorch 2.3.1 DEBUG compiled w/ CUDA 12.1
Running with Python 3.10 and CUDA 12.1.105

`pip3 list` truncated output:
numpy==1.26.4
optree==0.11.0
torch==2.3.1
torchaudio==2.3.1
torchelastic==0.2.2
torchvision==0.18.1
triton==2.3.1
--------------------------------------------------------------------------------
  cProfile output
--------------------------------------------------------------------------------
         9476410 function calls (9048717 primitive calls) in 69.741 seconds

   Ordered by: internal time
   List reduced from 19478 to 15 due to restriction <15>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        4   27.139    6.785   27.139    6.785 {method 'run_backward' of 'torch._C._EngineBase' objects}
     8299   11.034    0.001   11.034    0.001 {method 'to' of 'torch._C.TensorBase' objects}
      256   10.440    0.041   10.440    0.041 {built-in method torch.nonzero}
      184    3.270    0.018    3.270    0.018 {method 'read' of '_ssl._SSLSocket' objects}
        9    1.371    0.152    1.371    0.152 {built-in method gc.collect}
       98    1.282    0.013    1.511    0.015 /opt/conda/lib/python3.10/site-packages/datasets/packaged_modules/json/json.py:91(_generate_tables)
     8880    0.709    0.000    0.709    0.000 {method 'write' of '_io.BufferedWriter' objects}
      200    0.509    0.003    0.509    0.003 {built-in method sentencepiece._sentencepiece.SentencePieceProcessor__EncodeAsPieces}
        2    0.424    0.212   12.241    6.121 /opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py:3930(_load_pretrained_model)
      384    0.409    0.001    0.409    0.001 {method 'uniform_' of 'torch._C.TensorBase' objects}
      356    0.357    0.001    0.357    0.001 {built-in method torch.tensor}
     2824    0.342    0.000    0.342    0.000 {built-in method torch._C._nn.linear}
     3782    0.322    0.000    0.322    0.000 {built-in method marshal.loads}
        6    0.300    0.050    0.300    0.050 {built-in method safetensors._safetensors_rust.serialize_file}
10948/10591    0.268    0.000    0.717    0.000 {method 'read' of '_io.BufferedReader' objects}


--------------------------------------------------------------------------------
  autograd profiler output (CPU mode)
--------------------------------------------------------------------------------
        top 15 events sorted by cpu_time_total

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        49.63%     512.807ms        49.91%     515.670ms     515.670ms             1  
enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        47.41%     489.910ms        47.67%     492.578ms     492.578ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.00%      48.000us        10.99%     113.508ms     113.508ms             1  
                             CheckpointFunctionBackward         0.48%       4.936ms        10.98%     113.460ms     113.460ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%      63.000us        10.97%     113.303ms     113.303ms             1  
                             CheckpointFunctionBackward         0.46%       4.715ms        10.96%     113.240ms     113.240ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%      84.000us        10.96%     113.225ms     113.225ms             1  
                             CheckpointFunctionBackward         0.52%       5.355ms        10.95%     113.141ms     113.141ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%      85.000us        10.87%     112.331ms     112.331ms             1  
                             CheckpointFunctionBackward         0.47%       4.840ms        10.86%     112.246ms     112.246ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.00%      28.000us        10.70%     110.606ms     110.606ms             1  
                             CheckpointFunctionBackward         0.53%       5.439ms        10.70%     110.578ms     110.578ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%     100.000us        10.66%     110.153ms     110.153ms             1  
                             CheckpointFunctionBackward         0.46%       4.804ms        10.65%     110.053ms     110.053ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%      61.000us        10.65%     110.015ms     110.015ms             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.033s

--------------------------------------------------------------------------------
  autograd profiler output (CUDA mode)
--------------------------------------------------------------------------------
        top 15 events sorted by cpu_time_total

        Because the autograd profiler uses the CUDA event API,
        the CUDA time column reports approximately max(cuda_time, cpu_time).
        Please ignore this output if your code does not use CUDA.

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        49.04%     550.595ms        49.80%     559.157ms     559.157ms     549.640ms        50.79%     559.161ms     559.161ms             1  
enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        46.98%     527.473ms        47.74%     535.991ms     535.991ms     525.553ms        48.56%     535.996ms     535.996ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%      72.000us        11.97%     134.347ms     134.347ms       6.000us         0.00%     117.401ms     117.401ms             1  
                             CheckpointFunctionBackward         0.63%       7.124ms        11.96%     134.269ms     134.269ms       1.358ms         0.13%     117.395ms     117.395ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%      60.000us        11.87%     133.307ms     133.307ms       7.000us         0.00%     115.531ms     115.531ms             1  
                             CheckpointFunctionBackward         0.63%       7.065ms        11.87%     133.242ms     133.242ms       1.159ms         0.11%     115.524ms     115.524ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%      77.000us        11.84%     132.936ms     132.936ms       6.000us         0.00%     115.295ms     115.295ms             1  
                             CheckpointFunctionBackward         0.74%       8.335ms        11.83%     132.854ms     132.854ms       1.143ms         0.11%     115.289ms     115.289ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%      93.000us        11.81%     132.612ms     132.612ms       6.000us         0.00%     114.613ms     114.613ms             1  
                             CheckpointFunctionBackward         0.62%       6.997ms        11.80%     132.515ms     132.515ms       1.088ms         0.10%     114.607ms     114.607ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%     135.000us        10.39%     116.621ms     116.621ms       7.000us         0.00%     116.508ms     116.508ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%      69.000us        10.38%     116.556ms     116.556ms       7.000us         0.00%     116.466ms     116.466ms             1  
                             CheckpointFunctionBackward         0.66%       7.381ms        10.37%     116.482ms     116.482ms       1.127ms         0.10%     116.459ms     116.459ms             1  
                             CheckpointFunctionBackward         0.64%       7.213ms        10.37%     116.481ms     116.481ms       1.131ms         0.10%     116.501ms     116.501ms             1  
autograd::engine::evaluate_function: CheckpointFunct...         0.01%     106.000us        10.37%     116.463ms     116.463ms       7.000us         0.00%     116.442ms     116.442ms             1  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.123s
Self CUDA time total: 1.082s

It sounds like data loading is the bottleneck. However, even when I pre-load the entire dataset into CUDA memory before training, like this,

train_dataloader = list(train_dataloader)
for i, batch in enumerate(train_dataloader):
    # one batch has the form {"cls": {"input_ids": [...], ...}, "clm": {"input_ids": [...], ...}}
    for task in batch:    # task is either "cls" or "clm"
        train_dataloader[i][task] = {k: v.cuda() for k, v in batch[task].items()}

and then measure the time spent inside my dataset’s __getitem__() function, it accounts for only 0.05 s of a 1.3 s iteration. I don’t understand why the profiler attributes about 0.5 s to the DataLoader.
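For reference, the shape of my data pipeline, reduced to a toy sketch (the dataset, field names, and shapes are made up); `pin_memory` plus `non_blocking=True` copies are the usual alternative to pre-loading everything onto the GPU:

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyMultiTaskDataset(Dataset):
    """Toy stand-in for my dataset; field names and shapes are made up."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {
            "cls": {"input_ids": torch.randint(0, 100, (16,))},
            "clm": {"input_ids": torch.randint(0, 100, (16,))},
        }

use_cuda = torch.cuda.is_available()
loader = DataLoader(
    ToyMultiTaskDataset(),
    batch_size=4,
    num_workers=2,        # tokenize/load in background worker processes
    pin_memory=use_cuda,  # page-locked host memory enables async H2D copies
)

for batch in loader:
    for task in batch:    # "cls" or "clm"
        batch[task] = {
            k: (v.cuda(non_blocking=True) if use_cuda else v)
            for k, v in batch[task].items()
        }
```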

Problem 2

I tried to implement DDP, mainly by following this official PyTorch guide. Unfortunately, when I do that, I get the following error message:

return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop.2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations.
Parameter at index 127 with name _orig_mod.base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight has been marked as ready twice. This means that multiple autograd engine  hooks have fired for this particular parameter during this iteration.

Does this mean my architecture simply can’t run under DDP? The error message suggests _set_static_graph() as a workaround “if this module graph does not change over iterations”, but in my case it does change, since I pass the data through the classifier part and the causal LM part separately, right? I tried a number of things after researching this error, but they all lead back to it.
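For context, what I’m ultimately trying to run under DDP, reduced to toy modules: a single wrapper whose one forward() covers both tasks, which as far as I understand is what DDP expects (a sketch under that assumption, not my actual code; sizes and names are illustrative):

```python
import torch
import torch.nn as nn

# Toy sketch: both task heads behind one forward(), so a DDP wrapper around
# MultiTaskModel would see exactly one forward/backward per iteration.
# (Sizes and names are illustrative, not the real Llama-2 modules.)
class MultiTaskModel(nn.Module):
    def __init__(self, hidden=4):
        super().__init__()
        self.backbone = nn.Linear(hidden, hidden)  # shared Llama-2 stand-in
        self.cls_head = nn.Linear(hidden, 2)
        self.lm_head = nn.Linear(hidden, 8)

    def forward(self, cls_inputs, clm_inputs):
        # One pass per task through the shared backbone, one combined loss.
        cls_loss = self.cls_head(self.backbone(cls_inputs)).sum()
        clm_loss = self.lm_head(self.backbone(clm_inputs)).sum()
        return cls_loss + clm_loss

model = MultiTaskModel()
# ddp_model = nn.parallel.DistributedDataParallel(model)  # after init_process_group
loss = model(torch.randn(3, 4), torch.randn(3, 4))
loss.backward()
```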

Problem 3

After I couldn’t get DDP to work, I decided to load the classifier and the causal LM onto separate GPUs instead, so that they can run in parallel within an iteration. However, I then get the following error message, even though I made sure that the data and the model for a given task are on the same device:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
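The check I mean looks roughly like this (a toy sketch; the helper is made up for illustration, and in my code I run it right before each forward pass):

```python
import torch

def assert_same_device(module, tensors):
    """Debug helper (made up for illustration): verify every input tensor
    lives on the same device as the module's parameters."""
    dev = next(module.parameters()).device
    for name, t in tensors.items():
        assert t.device == dev, f"{name} is on {t.device}, but the model is on {dev}"

# usage with a toy module on the CPU
m = torch.nn.Linear(4, 4)
assert_same_device(m, {"input_ids": torch.randn(2, 4)})
```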

Even if this worked, I realized that moving the two models to separate GPUs would decouple them again, effectively making them two completely separate models: they would no longer share the common Llama-2 backbone, and implementing backward() across the two would become very difficult.

I’d appreciate it if you could share any ideas at all that come to mind.