Is there a way to shard a subset of the weights and replicate the others, to minimize communication overhead?
I have tried the following approaches, but neither works.
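To make the goal concrete, this is the layout I'm after, written directly with DTensor primitives (just a sketch using distribute_tensor to illustrate the target placements, not something fully_shard produces):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, Replicate, distribute_tensor

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
# large 2-D weights: sharded along dim 0 across all ranks
weight = distribute_tensor(torch.randn(10, 10), mesh, [Shard(0)])
# small 1-D params (biases, norm scales): a full copy on every rank, no gather needed
bias = distribute_tensor(torch.randn(10), mesh, [Replicate()])

The full repro: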
import os
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import Shard, Replicate, DTensor
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(3):
            self.layers.append(nn.Linear(10, 10))

    def forward(self, x):
        for layer in self.layers:
            x = F.relu(layer(x))
        return x
class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(2):
            self.blocks.append(Block())

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
def get_ignored_params(model):
    ignored_params = set()
    for name, param in model.named_parameters():
        # replicate the param if it is 1-D or smaller than 4 MB
        if param.ndim == 1:
            ignored_params.add(param)
        elif (np.prod(param.shape) * param.dtype.itemsize) / 1e6 < 4.0:
            ignored_params.add(param)
    return ignored_params
def shard_placement_fn(param):
    if param.ndim == 1:
        return Replicate()
    elif (np.prod(param.shape) * param.dtype.itemsize) / 1e6 < 4.0:
        return Replicate()
    return Shard(0)
def main():
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device_id = rank % torch.cuda.device_count()
    device = torch.device(f"cuda:{device_id}")
    # Build on the meta device so weights are materialized only after sharding.
    with torch.device("meta"):
        model = TestModel()
    fsdp_kwargs = {}
"""
Setting 'ignored_params' does not give an error but non-sharded params are not synced (different value and grad!)
Setting 'shard_plancement_fn' leads to AttributeError: 'Replicate' object has no attribute 'dim'
"""
    # fsdp_kwargs["ignored_params"] = get_ignored_params(model)
    # fsdp_kwargs["shard_placement_fn"] = shard_placement_fn
    for i, block in enumerate(model.blocks):
        model.blocks[i] = fully_shard(block, **fsdp_kwargs)
    model = fully_shard(model, **fsdp_kwargs)
    model.to_empty(device="cuda")
    for _, submodule in model.named_modules():
        if hasattr(submodule, "reset_parameters"):
            submodule.reset_parameters()
    optimizer = torch.optim.AdamW(model.parameters(), 1e-2)
    x = torch.ones(5, 10).to(device)
    for i in range(10):
        optimizer.zero_grad()
        loss = model(x).mean()
        loss.backward()
        optimizer.step()
        last_bias = model.blocks[-1].layers[-1].bias
        last_grad = last_bias.grad  # read before .full_tensor(), which returns a fresh tensor with .grad=None
        if isinstance(last_bias, DTensor):
            last_bias = last_bias.full_tensor()
        if isinstance(last_grad, DTensor):
            last_grad = last_grad.full_tensor()
        print(f"[RANK {rank}] {i}, {loss}, {last_bias}, {last_grad}", flush=True)
if __name__ == "__main__":
    main()
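For reference, this is the manual workaround I'd expect to need with the ignored_params route (a sketch, assuming the ignored params stay plain tensors outside FSDP's control; sync_ignored_params and allreduce_ignored_grads are my own helper names): broadcast the ignored params once after initialization, then all-reduce their grads after every backward.

def sync_ignored_params(params):
    # One-time broadcast so every rank starts from rank 0's values.
    for p in params:
        dist.broadcast(p.detach(), src=0)

def allreduce_ignored_grads(params):
    # Average the replicated params' grads across ranks after backward(), like DDP does.
    for p in params:
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)

Doing this bookkeeping by hand next to FSDP feels error-prone, though, which is why I'm hoping there is a built-in way.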