Alternating Parameters in DDP

Hi Guys,

My question and context might be a bit long, but I think it will help a lot of people trying to train large models with limited GPU memory under a DDP setup. After trying a lot of things (flash-attention, fp16, etc.), I am now trying to leverage the fact that setting requires_grad=False saves GPU memory.

The most intuitive description of what I want is:

  • I have a neural network with two layers layer1 and layer2.
  • I use two separate optimizers: opt1 for layer1 and opt2 for layer2.
  • During the training process, I alternate between the two optimizers and set the layers with requires_grad=False.

The following code works perfectly well in the single-GPU case and squeezes the training just under the GPU memory limit:

# initialization
net = ...
opt1, opt2 = ..., ...

# training
for step, data in enumerate(data_loader):
    # alternating the parameter sets
    if step % 2 == 0:
        for param in net.layer1.parameters():
            param.requires_grad = True
            param.grad = torch.zeros_like(param.data)
        for param in net.layer2.parameters():
            param.requires_grad = False
            del param.grad
        opt = opt1
    else:
        for param in net.layer2.parameters():
            param.requires_grad = True
            param.grad = torch.zeros_like(param.data)
        for param in net.layer1.parameters():
            param.requires_grad = False
            del param.grad
        opt = opt2
    
    loss = net(data)
    loss.backward()
    opt.step()
    opt.zero_grad()
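As a sanity check of the property I am relying on, here is a toy single-process sketch (a tiny two-layer model, not my real net) showing that frozen parameters never receive a .grad, which is where the memory saving comes from:

```python
import torch

# Toy stand-in for the real net: two layers, freeze layer1 for this "step".
net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))
layer1, layer2 = net[0], net[1]

for p in layer1.parameters():
    p.requires_grad = False  # freeze layer1 this step

loss = net(torch.randn(2, 4)).sum()
loss.backward()

# frozen params get no gradient tensors at all; active ones do
print(all(p.grad is None for p in layer1.parameters()))      # True
print(all(p.grad is not None for p in layer2.parameters()))  # True
```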

Then I try to switch to multi-GPU training with DDP as below. Please note that I re-wrap the model in DDP after each alternation of the parameter set so that DDP can correctly register and reduce the gradients.

# initialization
net = ...
ddp_net = DDP(net)
opt1, opt2 = ..., ...

for step, data in enumerate(data_loader):
    unwrap_net = unwrap(ddp_net) # remove the ddp wrapper
    # alternating the parameter sets
    if step % 2 == 0:
        for param in unwrap_net.layer1.parameters():
            param.requires_grad = True
            param.grad = torch.zeros_like(param.data)
        for param in unwrap_net.layer2.parameters():
            param.requires_grad = False
            del param.grad
        opt = opt1
    else:
        for param in unwrap_net.layer2.parameters():
            param.requires_grad = True
            param.grad = torch.zeros_like(param.data)
        for param in unwrap_net.layer1.parameters():
            param.requires_grad = False
            del param.grad
        opt = opt2
    
    new_ddp_net = DDP(unwrap_net)
    del ddp_net
    ddp_net = new_ddp_net
    torch.cuda.empty_cache()
    
    loss = ddp_net(data)
    loss.backward()
    opt.step()
    opt.zero_grad()

Here’s the question: I found that [alternating the parameter sets] occupies more GPU memory than [training only layer1 or layer2 alone with DDP]. Although the difference is just 1–2GB, I can no longer squeeze the large model into my GPU ( :sob: :sob: :sob:). So I am wondering:

  • Is there anything I am doing wrong that causes this issue?
  • Are there better approaches to alternating the parameter sets under a DDP setting?
  • Considering that I am using Hugging Face’s accelerate, could my observed issue actually come from accelerate rather than PyTorch itself?

If you have read through here, my friend, I am truly grateful for your patience and help! May the force of optimization be with you!

Best,

Ziqi

hi Ziqi,

I’m not sure whether you’ve solved this problem since March 2024 :joy: but it’s actually a good question that I’m also curious about. So I asked ChatGPT; below is its answer:

Quick summary — root causes and high-level fixes

Your extra ~1–2GB of GPU memory mostly comes from three places:

  1. DDP internal state / communication buffers — wrapping a model in DistributedDataParallel registers reduction hooks and allocates buckets/communication buffers for the parameter set at wrap time. Frequently unwrapping / re-wrapping, or modifying parameters after wrapping, can leave extra allocations or mismatched hooks.

  2. Optimizer state kept on GPU — Adam/momentum/etc. states (exp_avg, exp_avg_sq, momentum buffers) live in GPU memory. Holding two optimizers or optimizer states for different parameter groups simultaneously multiplies GPU memory usage.

  3. find_unused_parameters=True and unused-parameter handling — if parameters are sometimes unused in a forward, DDP bookkeeping (or enabling find_unused_parameters) causes extra memory and CPU/GPU overhead.

Short fixes: don’t change model parameters (or requires_grad) after wrapping DDP; if you must alternate parameter groups, free or offload optimizer state for frozen groups; consider FSDP/ZeRO/optimizer-sharding for a robust long-term solution.


Why exactly you see +1–2GB

  • DDP registers hooks & buckets at wrap time. The DDP wrapper assumes the model’s parameter set is stable after wrapping — altering parameters or frequently re-wrapping can create extra internal allocations, mismatched hooks, or temporary buffers.

  • Optimizer state is expensive. Each parameter that’s optimized often has one or more state tensors. If you keep two optimizers (or two sets of optimizer states) on GPU, both sets consume GPU memory concurrently. In DDP this effect is per-process, so memory pressure increases quickly.

  • Unused-parameter detection has a cost. If some parameters are not touched in a forward, DDP either errors out or requires find_unused_parameters=True, which triggers additional bookkeeping and memory usage.
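To make the optimizer-state point concrete, here is a small single-process sketch (the helper optimizer_state_bytes is illustrative, not a PyTorch API) that measures how much memory Adam state adds on top of the parameters themselves:

```python
import torch

# Illustrative helper (not a PyTorch API): total bytes held by an optimizer's
# state tensors (exp_avg, exp_avg_sq, ... for Adam).
def optimizer_state_bytes(opt):
    total = 0
    for state in opt.state.values():
        for v in state.values():
            if isinstance(v, torch.Tensor):
                total += v.numel() * v.element_size()
    return total

model = torch.nn.Linear(1000, 1000)          # ~1M parameters
opt = torch.optim.Adam(model.parameters())
model(torch.randn(8, 1000)).sum().backward()
opt.step()                                   # Adam lazily allocates its state here

param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(optimizer_state_bytes(opt) / param_bytes)  # roughly 2x the parameter memory
```

For Adam the state is roughly twice the parameter memory, per process — which is why keeping two full optimizers resident on GPU hurts so quickly.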


Immediate, practical ways to reduce GPU memory (ranked)

  1. Release / move optimizer state for frozen params — when you freeze a parameter group, move its optimizer state tensors to CPU (or delete them). This often yields the biggest immediate memory win.

  2. Avoid frequent re-wrap/unwrapping of DDP — if you must reconstruct DDP, be sure to del old objects, gc.collect(), call torch.distributed.barrier() and torch.cuda.empty_cache() (note: empty_cache() releases PyTorch cache, not necessarily OS memory). But prefer avoiding re-wraps at runtime.

  3. Adopt FSDP / ZeRO / optimizer sharding — these solutions shard parameters/gradients/optimizer state across ranks or offload state to CPU; they are the correct long-term approach for large-model memory problems.

  4. If you can’t avoid unused parameters: use find_unused_parameters=True only when necessary, accepting the performance and memory overhead.

  5. When reconstructing DDP, fully destroy old optimizer/state — don’t retain old optimizer objects with GPU tensors.


Code snippets you can try right away

A. Move optimizer state for frozen params to CPU (or delete it):

import torch

# opt: current optimizer
# params_to_free: iterable of parameter objects (same objects used as keys in opt.state)
def offload_optimizer_state_to_cpu(opt, params_to_free):
    params_to_free = set(params_to_free)
    for p in list(opt.state.keys()):
        if p in params_to_free:
            state = opt.state[p]
            for k, v in list(state.items()):
                if isinstance(v, torch.Tensor):
                    # move tensor to CPU to free GPU memory
                    state[k] = v.detach().to('cpu')
            # alternatively, to delete the state entirely:
            # del opt.state[p]

Note: if you later re-enable training for those params you must reinitialize their optimizer state (or accept fresh state).
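If you move (rather than delete) the state, a mirror helper can bring it back just before you unfreeze the group — again a sketch with my own names, not a PyTorch API:

```python
import torch

# Illustrative mirror of offload_optimizer_state_to_cpu: move the saved state
# tensors back onto each parameter's own device before the group is trained again.
def restore_optimizer_state(opt, params_to_restore):
    params_to_restore = set(params_to_restore)
    for p, state in opt.state.items():
        if p in params_to_restore:
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(p.device, non_blocking=True)

# tiny CPU demo: the round trip is the same on GPU, where it actually frees memory
model = torch.nn.Linear(3, 3)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(1, 3)).sum().backward()
opt.step()  # allocate Adam state
restore_optimizer_state(opt, model.parameters())
```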

B. Properly destroy and rebuild DDP and optimizer (if you must):

import gc

# before switching groups:
del ddp_model
del optimizer
gc.collect()
torch.distributed.barrier()
torch.cuda.empty_cache()

# change requires_grad on the underlying model
for p in model.layer1.parameters():
    p.requires_grad = True
for p in model.layer2.parameters():
    p.requires_grad = False

# rebuild DDP and optimizer
ddp_model = torch.nn.parallel.DistributedDataParallel(model, ...)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, ddp_model.parameters()), lr=...)

Warning: reconstructing frequently is costly; avoid this in the training inner loop.

C. Longer-term: use FSDP / optimizer-sharding (conceptual)
Use PyTorch FSDP or ZeRO-like sharding (or fairscale/accelerate integrations) to shard parameters, gradients, and optimizer state across ranks so single-GPU memory requirements drop dramatically.
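With Hugging Face accelerate, FSDP is usually enabled through the config file rather than code. A sketch of the relevant fragment — key names vary between accelerate versions, so run `accelerate config` to generate the exact ones for yours:

```yaml
# Illustrative accelerate config fragment (verify keys against your version)
distributed_type: FSDP
fsdp_config:
  fsdp_sharding_strategy: FULL_SHARD   # shard params, grads, and optimizer state
  fsdp_offload_params: true            # optionally offload shards to CPU
```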


Debugging tips to find where memory goes

  • Print torch.cuda.memory_summary(device) to see allocated vs reserved memory.

  • Check nvidia-smi to inspect per-process GPU memory usage.

  • After each major step (wrap, optimizer creation, freeze/unfreeze), call gc.collect() and torch.cuda.empty_cache() and observe changes.

  • Make sure no other Python objects hold references to model tensors (lists, logging buffers, checkpoints, etc.) which prevent freeing.
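The tips above can be wrapped into a tiny helper (my own, not a PyTorch API) that you call after each major step and diff the readings:

```python
import gc
import torch

# Illustrative debugging helper: report allocated vs reserved CUDA memory with a
# tag, so you can diff readings after each major step (DDP wrap, optimizer
# creation, freeze/unfreeze). Falls back gracefully on CPU-only machines.
def log_cuda_memory(tag):
    gc.collect()
    if not torch.cuda.is_available():
        msg = f"{tag}: CUDA not available"
    else:
        torch.cuda.empty_cache()
        alloc = torch.cuda.memory_allocated() / 2**20
        reserved = torch.cuda.memory_reserved() / 2**20
        msg = f"{tag}: allocated={alloc:.1f} MiB, reserved={reserved:.1f} MiB"
    print(msg)
    return msg

log_cuda_memory("after DDP wrap")
```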


Final recommendation

  • Short-term: Inspect optimizer state and offload or delete state tensors for frozen parameter groups. Avoid holding two full optimizer states on GPU simultaneously.

  • Medium/long-term: Move to FSDP / ZeRO / optimizer-sharding if memory is a recurring bottleneck — these are the robust solutions designed for alternating parameter usage or very large models.

If you want, I can rewrite a concrete portion of your training loop to (1) offload optimizer state when freezing, or (2) show a minimal FSDP example (with accelerate or raw PyTorch FSDP). Paste the relevant code and I’ll convert it into a drop-in change.