Why does ZeroRedundancyOptimizer consume less memory during training than FullyShardedDataParallel?

I trained a Transformer model with 1B parameters on servers with 8 A100 GPUs.

As far as I know, ZeroRedundancyOptimizer is based on ZeRO-1 and FullyShardedDataParallel is based on ZeRO-3, so FSDP should reduce GPU memory consumption more.

However, in practice, GPU memory consumption during training is lower when I use ZeroRedundancyOptimizer.

Is this a normal phenomenon?

The PyTorch version I am using is 2.0.1. The following is the code I used for testing:

# FSDP
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy

# Assumes the script is launched with torchrun and the process group is already
# initialized (torch.distributed.init_process_group); my_module is the 1B-parameter model.
device = torch.device('cuda')
sharded_module = FSDP(my_module, sharding_strategy=ShardingStrategy.FULL_SHARD)

optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
dummy_input = torch.randn((1, 3, 224, 224), device=device)

# One training step
sharded_module.train()
x = sharded_module(dummy_input)
loss = x.sum()
loss.backward()
optim.step()
optim.zero_grad()

# ZeroRedundancyOptimizer
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel

# Assumes the script is launched with torchrun, the process group is already
# initialized, and my_module has been moved to the current CUDA device.
device = torch.device('cuda')
model = DistributedDataParallel(
    my_module,
    device_ids=[torch.cuda.current_device()],
    find_unused_parameters=True)

optim = ZeroRedundancyOptimizer(model.parameters(),
                                optimizer_class=torch.optim.Adam,
                                lr=0.0001)
dummy_input = torch.randn((1, 3, 224, 224), device=device)

# One training step
model.train()
x = model(dummy_input)
loss = x.sum()
loss.backward()
optim.step()
optim.zero_grad()

When you used ZeroRedundancyOptimizer, did you use DDP?

cc: @agu

Yes, I used DDP when I used ZeroRedundancyOptimizer.

How are you using FSDP? Are you using nested FSDP wrapping?

My code is already posted in the description; I am not using nested FSDP wrapping.

Sorry, I missed the code. If you do not use nested FSDP wrapping, then you will incur a (1 + 1/W) memory overhead from the parameters for W workers, since the full parameters are all-gathered at once while FSDP always maintains its 1/W shard. That could be why you see higher GPU memory usage with FSDP.

Sorry, I did not understand your meaning. Are you saying that FSDP needs nested wrapping to achieve lower GPU memory consumption? Is there any sample code for reference?

Yes, nested wrapping helps lower the GPU memory usage.

You can try something like:

from torch.distributed.fsdp.wrap import ModuleWrapPolicy
policy = ModuleWrapPolicy({TransformerBlock})
sharded_module = FSDP(my_module, auto_wrap_policy=policy, ...)

where you can replace TransformerBlock with whatever your transformer block’s module class is.
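For a self-contained illustration, here is a minimal sketch in which TransformerBlock is a made-up stand-in for your real block class (the depth and hidden size are arbitrary, and the process group is assumed to already be initialized, e.g. via torchrun):

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

class TransformerBlock(nn.Module):  # stand-in for your real transformer block class
    def __init__(self, d=1024):
        super().__init__()
        self.attn = nn.MultiheadAttention(d, num_heads=8, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):
        x = x + self.attn(x, x, x, need_weights=False)[0]
        return x + self.mlp(x)

# Assumes torch.distributed.init_process_group has already been called
model = nn.Sequential(*[TransformerBlock() for _ in range(24)])
policy = ModuleWrapPolicy({TransformerBlock})
sharded_module = FSDP(model, auto_wrap_policy=policy,
                      device_id=torch.cuda.current_device())
print(sharded_module)  # each TransformerBlock now appears as its own nested FSDP unit

Printing the wrapped module is a quick way to confirm that the policy produced one FSDP unit per block rather than a single flat wrap.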


Thanks, it does work! Is it true that ZeroRedundancyOptimizer is effectively already "nested wrapped", while FSDP requires setting a policy to achieve nested wrapping?

The FSDP wrapping effectively determines the max parameter size that is unsharded (all-gathered) at once, so if you only wrap at the top-level, then all parameters contribute to that max size. Therefore, nested wrapping is necessary to decrease the max memory contribution from parameters.

For ZeroRedundancyOptimizer, the implementation is different. The optimizer states are greedily partitioned across ranks to minimize unevenness. It does not shard parameters or gradients.
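To make that concrete, here is a toy sketch of the greedy partitioning idea (this is only an illustration, not the actual ZeroRedundancyOptimizer code; the parameter sizes are made up):

def greedy_partition(param_numels, world_size):
    loads = [0] * world_size  # total numel of optimizer state assigned to each rank
    assignment = {}
    # Assign each parameter (largest first) to the currently least-loaded rank
    for idx, numel in sorted(enumerate(param_numels), key=lambda p: p[1], reverse=True):
        rank = loads.index(min(loads))
        assignment[idx] = rank
        loads[rank] += numel
    return assignment, loads

assignment, loads = greedy_partition([400, 100, 300, 300, 200], world_size=2)
print(assignment, loads)  # states split roughly evenly, e.g. loads == [700, 600]

Each rank then only allocates and updates the optimizer states for its own partition, while parameters and gradients stay fully replicated as in plain DDP.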

Thanks for your reply, it’s really helpful. But I still have some questions. Is ZeroRedundancyOptimizer an implementation of ZeRO-1? If so, why not implement ZeRO-2 and ZeRO-3 on top of ZeroRedundancyOptimizer instead of introducing FSDP separately? Is there a reason for this?

One reason is that ZeroRedundancyOptimizer is a torch.optim.Optimizer, but in order to shard parameters/gradients, we need something that interfaces with nn.Module. Hence, FullyShardedDataParallel is an nn.Module (wrapper).
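You can see that division of labor reflected directly in the class hierarchy:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.optim import ZeroRedundancyOptimizer

# ZeroRedundancyOptimizer only manages optimizer state, so it lives in the optimizer hierarchy
print(issubclass(ZeroRedundancyOptimizer, torch.optim.Optimizer))  # True
# FSDP hooks into forward/backward to shard parameters and gradients, so it is a module wrapper
print(issubclass(FSDP, torch.nn.Module))  # True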

I see, thanks for your patience.

Hello, I still have a question that I don’t understand. When FSDP does not use nested wrapping, its memory consumption is higher than ZeroRedundancyOptimizer’s, even though FSDP’s full sharding already covers what ZeRO-1 does. Is this caused by the different implementations of FSDP and ZeroRedundancyOptimizer? Where does FSDP’s extra memory consumption come from?

It is an implementation detail of FSDP. See the description from above:

If you do not use nested FSDP wrapping, then you will incur a (1 + 1/W) memory overhead from the parameters for W workers, since the full parameters are all-gathered at once while FSDP always maintains its 1/W shard. That could be why you see higher GPU memory usage with FSDP.
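As a rough back-of-the-envelope illustration of that 1 + 1/W factor (counting parameter memory only, and assuming fp32 parameters, a 1B-parameter model, and W = 8 purely for the arithmetic):

num_params = 1_000_000_000
bytes_per_param = 4                                # fp32
W = 8
full_gb = num_params * bytes_per_param / 1e9       # ~4.0 GB: all parameters unsharded at once
shard_gb = full_gb / W                             # ~0.5 GB: the 1/W shard FSDP always keeps
print(f"flat (top-level-only) FSDP, params only: {full_gb + shard_gb:.1f} GB")  # ~4.5 GB = (1 + 1/W) * full
print(f"DDP + ZeroRedundancyOptimizer, params only: {full_gb:.1f} GB")          # ~4.0 GB: one full replica

With nested wrapping, only one block’s parameters are all-gathered at a time, so the unsharded term shrinks from the full 4 GB to roughly one block’s worth.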

I understand, thank you!