Can I shard a subset of weights and replicate others in FSDP2?

Is there a way to shard a subset of the weights and replicate others to minimize the communication overhead?

I have tried following approaches, but they don’t work.

import os

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard
from torch.distributed.tensor import Shard, Replicate, DTensor


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(3):
            self.layers.append(nn.Linear(10, 10))

    def forward(self, x):
        for layer in self.layers:
            x = F.relu(layer(x))
        return x


class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(2):
            self.blocks.append(Block())

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


def get_ignored_params(model):
    ignored_params = set()
    for name, param in model.named_parameters():
        # replicate if param is 1-dim or less than 4Mb
        if param.ndim == 1:
            ignored_params.add(param)
        elif (np.prod(param.shape) * param.dtype.itemsize) / 1e6 < 4.0:
            ignored_params.add(param)
    return ignored_params


def shard_placement_fn(param):
    if param.ndim == 1:
        return Replicate()
    elif (np.prod(param.shape) * param.dtype.itemsize) / 1e6 < 4.0:
        return Replicate()
    return Shard(0)


def main():
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    device_id = rank % torch.cuda.device_count()
    device = torch.device(f"cuda:{device_id}")

    with torch.device("meta"):
        model = TestModel()

    fsdp_kwargs = {}
    """
    Setting 'ignored_params' does not give an error but non-sharded params are not synced (different value and grad!)
    Setting 'shard_plancement_fn' leads to AttributeError: 'Replicate' object has no attribute 'dim'
    """
    # fsdp_kwargs["ignored_params"] = get_ignored_params(model)
    # fsdp_kwargs["shard_placement_fn"] = shard_placement_fn

    for i, block in enumerate(model.blocks):
        model.blocks[i] = fully_shard(block, **fsdp_kwargs)
    model = fully_shard(model, **fsdp_kwargs)

    model.to_empty(device="cuda")
    for _, submodule in model.named_modules():
        if hasattr(submodule, "reset_parameters"):
            submodule.reset_parameters()

    optimizer = torch.optim.AdamW(model.parameters(), 1e-2)

    x = torch.ones(5, 10).to(device)

    for i in range(10):
        optimizer.zero_grad()
        loss = model(x).mean()
        loss.backward()
        optimizer.step()
        last_bias = model.blocks[-1].layers[-1].bias
        if isinstance(last_bias, DTensor):
            last_bias = last_bias.full_tensor()
        print(f"[RANK {rank}] {i}, {loss}, {last_bias}, {last_bias.grad}", flush=True)


if __name__ == "__main__":
    main()

1 Like