I trained a Transformer model with 1B parameters on servers with 8 A100 GPUs.
As far as I know, ZeroRedundancyOptimizer is based on ZeRO-1 and FullyShardedDataParallel is based on ZeRO-3, so FSDP should reduce GPU memory consumption more.
However, in reality, when using ZeroRedundancyOptimizer, the GPU memory consumption during training is lower.
Is this a normal phenomenon?
I am using PyTorch 2.0.1. The following is the code I used for testing:
# FSDP
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
device = torch.device('cuda')
sharded_module = FSDP(my_module, sharding_strategy=ShardingStrategy.FULL_SHARD)
optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
dummy_input = torch.randn((1, 3, 224, 224), device=device)
sharded_module.train()
x = sharded_module(dummy_input)
loss = x.sum()
loss.backward()
optim.step()
optim.zero_grad()
# ZeroRedundancyOptimizer
import torch
from torch.distributed.optim.zero_redundancy_optimizer import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel
device = torch.device('cuda')
model = DistributedDataParallel(
    my_module, device_ids=[torch.cuda.current_device()],
    find_unused_parameters=True)
optim = ZeroRedundancyOptimizer(model.parameters(), torch.optim.Adam,
                                lr=0.0001)
dummy_input = torch.randn((1, 3, 224, 224), device=device)
model.train()
x = model(dummy_input)
loss = x.sum()
loss.backward()
optim.step()
optim.zero_grad()
Sorry, I missed the code. If you do not use nested FSDP wrapping, then peak parameter memory is (1 + 1/W) times the full parameter size for W workers: the fully unsharded (all-gathered) parameters plus the 1/W shard that FSDP always maintains. That could be why you see higher GPU memory usage with FSDP.
Sorry, I did not quite follow. Are you saying that FSDP needs nested wrapping to achieve lower GPU memory consumption? Is there any sample code for reference?
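For reference, here is a minimal sketch of nested wrapping through the `auto_wrap_policy` argument. It assumes the model is built from `nn.TransformerEncoderLayer` blocks, which may not match your model; substitute your own block class. The helper name `wrap_with_nested_fsdp` is made up for illustration, and it has to be called after the process group is initialized and the CUDA device is set.

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumption: each Transformer block is an nn.TransformerEncoderLayer.
# Replace this set with the block class(es) your model actually uses.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)


def wrap_with_nested_fsdp(module: nn.Module) -> FSDP:
    """Wrap `module` so that every matching block becomes its own FSDP unit.

    Hypothetical helper: needs an initialized process group and a current
    CUDA device, so it is only defined here, not called.
    """
    return FSDP(
        module,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=auto_wrap_policy,
        device_id=torch.cuda.current_device(),
    )
```

With a policy like this, only one block's parameters should need to be all-gathered at a time instead of the whole model. If your model has no single repeated block class, `size_based_auto_wrap_policy` from the same module is an alternative that wraps by parameter count.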
Thanks, it does work! Is it true that ZeroRedundancyOptimizer is already nested wrapped, while FSDP requires setting a policy to achieve nested wrapping?
The FSDP wrapping effectively determines the max parameter size that is unsharded (all-gathered) at once, so if you only wrap at the top-level, then all parameters contribute to that max size. Therefore, nested wrapping is necessary to decrease the max memory contribution from parameters.
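To put rough numbers on this, here is a back-of-the-envelope sketch. It assumes fp32 parameters, that only one FSDP unit is unsharded at a time (prefetching can keep more than one alive), and it ignores gradients, optimizer state, and activations. The function name and the split into 25 equal units are illustrative assumptions, not something FSDP reports.

```python
def peak_param_bytes(unit_numels, world_size, bytes_per_param=4):
    """Estimate peak per-rank parameter memory for a given FSDP wrapping.

    unit_numels: parameter count of each FSDP unit (one entry per wrapped unit).
    """
    total = sum(unit_numels)
    sharded = total / world_size          # the 1/W shard is always resident
    largest_unsharded = max(unit_numels)  # biggest unit all-gathered at once
    return (sharded + largest_unsharded) * bytes_per_param


# 1B parameters on 8 GPUs, wrapped only at the top level: (1 + 1/W) * size.
top_level_only = peak_param_bytes([1_000_000_000], world_size=8)

# Same model split into 25 equal units (hypothetical wrapping granularity).
nested = peak_param_bytes([40_000_000] * 25, world_size=8)

print(f"top-level only: {top_level_only / 1e9:.2f} GB")
print(f"nested:         {nested / 1e9:.2f} GB")
```

Under these assumptions the top-level-only case comes to about 4.5 GB of parameter memory per rank, versus about 0.66 GB with nested wrapping.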
For ZeroRedundancyOptimizer, the implementation is different. The optimizer states are greedily partitioned across ranks to minimize unevenness. It does not shard parameters or gradients.
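As an illustration of what greedy partitioning means here, below is a small sketch in plain Python: each parameter goes to whichever rank currently holds the fewest elements. It is not the library's code, and the largest-first ordering is an assumption made for the example.

```python
def greedy_partition(param_sizes, world_size):
    """Assign parameter indices to ranks, always picking the lightest rank."""
    assignments = [[] for _ in range(world_size)]
    totals = [0] * world_size
    # Visit parameters from largest to smallest (assumed ordering).
    order = sorted(range(len(param_sizes)),
                   key=lambda i: param_sizes[i], reverse=True)
    for idx in order:
        rank = totals.index(min(totals))  # first rank with the smallest load
        assignments[rank].append(idx)
        totals[rank] += param_sizes[idx]
    return assignments, totals


assignments, totals = greedy_partition([8, 7, 6, 5, 4], world_size=2)
print(assignments, totals)
```

Each rank then keeps optimizer state only for the parameters assigned to it, while the parameters and gradients themselves stay fully replicated by DDP.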
Thanks for your reply, it’s really helpful. But I still have some questions. Is ZeroRedundancyOptimizer an implementation of ZeRO-1? If so, why not continue implementing ZeRO-2 and ZeRO-3 based on ZeroRedundancyOptimizer instead of introducing FSDP separately? Is there a reason for this?
One reason is that ZeroRedundancyOptimizer is a torch.optim.Optimizer, but in order to shard parameters/gradients, we need something that interfaces with nn.Module. Hence, FullyShardedDataParallel is an nn.Module (wrapper).
Hello, there is still something I don't understand. Without nested wrapping, FSDP's memory consumption is higher than ZeroRedundancyOptimizer's, yet FSDP actually includes ZeRO-1. Is this caused by implementation differences between FSDP and ZeroRedundancyOptimizer? Where does FSDP's extra memory consumption come from?
It is an implementation detail of FSDP. See the description from above:
If you do not use nested FSDP wrapping, then peak parameter memory is (1 + 1/W) times the full parameter size for W workers: the fully unsharded (all-gathered) parameters plus the 1/W shard that FSDP always maintains. That could be why you see higher GPU memory usage with FSDP.