Question about GPU memory usage when training with pipeline parallelism under larger micro-batch counts

I am using torchtitan with FSDP2 + PP (1F1B) to train llama3-8b. However, I found that as the micro-batch count increases, the GPU memory usage on the last PP stage grows rapidly (from 42.27 GB to 65.47 GB). That’s a bit strange; AFAIK, GPU memory should not increase much just because we use a larger micro-batch count (after all, we need a larger micro-batch count to reduce the bubble rate).

Here are my experiment settings and results:

using 4 GPUs and DP2-PP2 to train llama3-8b (pruned, 16 layers)
pytorch version: 2.9.0.dev20250725+cu126
torchtitan version: v0.1.0

command:

LOG_RANK=2 NGPU=4 CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml ./run_train.sh --metrics.log_freq 1  --training.seq_len 4096 --training.steps 1000 --parallelism.data_parallel_shard_degree 2 --activation_checkpoint.mode full --parallelism.pipeline_parallel_degree 2  --training.local_batch_size 16

memory: torch max reserved memory
rank0 is PP stage 0, rank2 is PP stage 1 (the last stage)

micro-batch count | memory (GB) on rank0 | memory (GB) on rank2
----------------- | -------------------- | --------------------
2                 | 35.75                | 40.30
4                 | 35.80                | 42.27
8                 | 35.94                | 48.27
16                | 36.20                | 65.47

Also, I dumped a memory snapshot of the last PP stage and found that the LM head output (shape: 4096 × 128256) of each micro-batch stays alive for the whole step.
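
As a rough sanity check on the size of those cached outputs, here is my own back-of-the-envelope estimate, assuming bf16 logits and micro-batch size 1 (i.e. micro-batch count 16 with local batch size 16):

# Rough size of one cached LM-head output per micro-batch
# (assumption: bf16 logits, seq_len 4096, vocab 128256, micro-batch size 1).
seq_len, vocab_size, bytes_per_elem = 4096, 128256, 2
logits_bytes = seq_len * vocab_size * bytes_per_elem
print(f"one micro-batch of logits: {logits_bytes / 2**30:.2f} GiB")      # ~0.98 GiB
print(f"16 cached micro-batches:  {16 * logits_bytes / 2**30:.2f} GiB")  # ~15.7 GiB

So with 16 micro-batches, roughly 16 GB of logits stay resident until the end of the step, on top of everything else on the last stage.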

I checked the pipeline parallelism implementation and found that PyTorch caches the chunk outputs after forward_maybe_with_nosync in the forward_one_chunk function, which may explain the GPU memory behavior above. My question is: do we really need to cache the chunk outputs during training? Perhaps we only need to cache outputs during evaluation or inference, or we could let users choose whether to enable output caching.

After disabling output caching with the modifications below, the GPU memory usage looks normal:

class _PipelineStageBase(ABC):
    def __init__(self, ...):
        ...
        # use `requires_merge_outputs` as a switch
        self.requires_merge_outputs: bool = True
    
    def forward_one_chunk(self, ...):
        ...
        output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs)
        if self.is_last and self.requires_merge_outputs:
            # only keep the chunk output alive if it needs to be merged/returned
            self.output_chunks.append(output)
        ...

class PipelineScheduleSingle(_PipelineSchedule):
    def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
        ...
        if self._stage.is_last and self._stage.requires_merge_outputs:
            return self._merge_outputs(self._stage.output_chunks)
        else:
            return None
    
    # expose to the user
    def set_requires_merge_outputs(self, requires_merge_outputs: bool):
        self._stage.requires_merge_outputs = requires_merge_outputs
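
With that patch, the training loop can opt out of output caching before stepping the schedule. A minimal sketch (Schedule1F1B and its step() signature are the existing PyTorch APIs; set_requires_merge_outputs is only the switch proposed above; stage, loss_fn, and target are assumed to be set up as in a normal PP training loop):

from torch.distributed.pipelining import Schedule1F1B

# `stage` is the PipelineStage wrapping this rank's model chunk, built as usual.
schedule = Schedule1F1B(stage, n_microbatches=16, loss_fn=loss_fn)

# Proposed switch: don't accumulate last-stage outputs during training.
schedule.set_requires_merge_outputs(False)

# On the last PP stage, step() now returns None instead of the merged logits,
# so each micro-batch's output can be freed after its backward instead of
# living until the end of the step.
losses = []
schedule.step(target=target, losses=losses)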

With output caching disabled this way, the GPU memory usage is much more stable:

micro-batch count | memory (GB) on rank0 | memory (GB) on rank2
----------------- | -------------------- | --------------------
2                 | 35.75                | 40.30
4                 | 35.80                | 40.38
8                 | 35.94                | 40.41
16                | 36.20                | 41.64

If we don’t save the outputs for the last PP stage, then the memory bottleneck will probably move to the middle stages, so the memory required may not differ that much. Have you tested what the max memory allocated is for the other ranks?

You still bring up a good point though, and that feature seems like something we should support. We have a new eval() API, so users can use that instead if they need the outputs saved. Can you file an issue on GitHub? We can track it better there.

EDIT: nvm I see you already filed an issue, thanks!

@H-Huang I added the max reserved GPU memory on the other ranks (stages) to the tables above; it seems that whether or not the forward outputs are cached on the last PP stage has no effect on the memory footprint of the other ranks.

By the way, you mentioned that “if we don’t save the outputs for the last PP stage, then the memory bottleneck will probably move to the middle stages”. I find this point a bit confusing; could you help me understand the reasoning behind it in a bit more detail when you have time?

Thanks in advance!

I didn’t mean to say that the caching will affect the memory of the middle stages; I was just wondering what the memory profile of the entire training job looks like now.

The max memory usage should be calculated as max(all_ranks), since if any one of the ranks OOMs then the training will fail. So I was just wondering what the memory usage of the middle ranks looks like: if the last stage was the previous bottleneck, where is the new bottleneck? Just curious :slight_smile:
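
For reference, that number can be collected with the standard torch.cuda / torch.distributed calls, e.g. (just a sketch, nothing torchtitan-specific):

import torch
import torch.distributed as dist

# Per-rank peak reserved memory since startup (or the last reset), in GiB.
local_peak = torch.tensor([torch.cuda.max_memory_reserved() / 2**30], device="cuda")

# The job's effective memory bound is the max over all ranks:
# if any single rank OOMs, the whole training run fails.
dist.all_reduce(local_peak, op=dist.ReduceOp.MAX)
if dist.get_rank() == 0:
    print(f"max reserved memory across ranks: {local_peak.item():.2f} GiB")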

Got it, thanks for clarifying!