Losses of FSDP NO_SHARD and DDP don't match

Hi there! I found that the training losses of FSDP NO_SHARD and DDP don't match. In my understanding, they implement the same algorithm and should produce the same loss curve. Or am I wrong, and since they are different implementations we shouldn't expect identical losses?

I'm running the following code (modified from the MNIST example in pytorch/examples). The losses match when running on 1 or 2 GPUs, but they diverge when using more than 2 GPUs.

from __future__ import print_function
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import always_wrap_policy
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def setup():
    if "OMPI_COMM_WORLD_SIZE" in os.environ:
        world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))
        rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
        master_ip = os.environ.get("MASTER_IP", "localhost")
        master_port = "8999"
        print(f"Initializing distributed 'tcp://{master_ip}:{master_port}'")
        dist.init_process_group(
            backend="nccl",
            init_method=f"tcp://{master_ip}:{master_port}",
            world_size=world_size,
            rank=rank,
        )
    else:
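        # No Open MPI variables set: use env:// init, which reads MASTER_ADDR,
        # MASTER_PORT, RANK, and WORLD_SIZE from the environment (e.g. torchrun).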
        dist.init_process_group(backend="nccl", init_method="env://")
        rank = dist.get_rank()
        world_size = dist.get_world_size()


def tear():
    dist.destroy_process_group()


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
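            # All-reduce the per-rank loss and average it so rank 0 logs a global loss.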
            loss_t = loss.detach().clone()
            dist.all_reduce(loss_t)
            loss_t /= dist.get_world_size()
            if dist.get_rank() == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.dataset),
                        100.0 * batch_idx / len(train_loader),
                        loss_t.item(),
                    )
                )
            if args.dry_run:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )


def main():
    setup()

    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--dist",
        "-d",
        type=str,
        choices=("ddp", "fsdp"),
        required=True,
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=14,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)"
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--dry-run", action="store_true", default=False, help="quickly check a single pass"
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--save-model", action="store_true", default=False, help="For Saving the current Model"
    )
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)
    torch.cuda.set_device(dist.get_rank())
    device = torch.cuda.current_device()
    print(f"Using cuda:{device}")

    train_kwargs = {"batch_size": args.batch_size}
    test_kwargs = {"batch_size": args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {"num_workers": 1, "pin_memory": True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    dataset1 = datasets.MNIST("data", train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST("data", train=False, transform=transform)
    dataset1_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset1, dist.get_world_size(), dist.get_rank(), shuffle=True
    )
    dataset2_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset2, dist.get_world_size(), dist.get_rank(), shuffle=False
    )
    train_loader = torch.utils.data.DataLoader(dataset1, sampler=dataset1_sampler, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, sampler=dataset2_sampler, **test_kwargs)

    model = Net().to(device)
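    # Broadcast the initial weights from rank 0 so every rank starts from identical parameters.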
    for param in model.parameters():
        dist.broadcast(param, 0)
    if args.dist == "ddp":
        model = DDP(model)
    else:
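        # NO_SHARD keeps a full copy of the parameters on each rank and all-reduces
        # gradients, which should behave like DDP.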
        model = FSDP(
            model,
            sharding_strategy=ShardingStrategy.NO_SHARD,
            auto_wrap_policy=always_wrap_policy,
            device_id=dist.get_rank(),
        )
    print(model)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")

    tear()


if __name__ == "__main__":
    # Force deterministic cuBLAS workspaces and deterministic algorithms so runs are reproducible.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

    main()

I use MPI (mpirun) to launch multiple processes:

mpirun -np 4 python main.py -d fsdp
mpirun -np 4 python main.py -d ddp
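
Since setup() falls back to env:// initialization when the Open MPI variables are absent, I believe (untested assumption on my side) the same script can also be launched with torchrun, which exports RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for each process:

torchrun --nproc_per_node=4 main.py -d fsdp
torchrun --nproc_per_node=4 main.py -d ddp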

System information:

GPU: V100
Python: 3.8.13
PyTorch: 1.12.1+cu116
torch.version.cuda: 11.6
torch.cuda.nccl.version(): 2.10.3

Cc @Yanli_Zhao @Rohan_Varma

Yes, FSDP NO_SHARD has a different implementation from DDP. In particular, their gradient bucketing strategies (an optimization for overlapping communication with computation) are slightly different, so gradients are reduced in a different order and the losses may not be exactly the same.
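
If it helps to confirm that the gap is only floating-point noise from the different reduction order, here is a minimal sketch that compares two saved state dicts, e.g. checkpoints produced with --save-model from a DDP run and an FSDP run. The helper name, the checkpoint filenames (renamed copies of mnist_cnn.pt from each run), and the prefix handling are my own assumptions, not part of the script above:

import torch

def compare_state_dicts(sd_a, sd_b, rtol=1e-5, atol=1e-6):
    # Hypothetical helper: compare two state dicts tensor by tensor and report
    # the largest absolute difference. Wrapper prefixes such as "module." (DDP)
    # or "_fsdp_wrapped_module." (FSDP, version dependent) are stripped so that
    # keys line up; adjust the prefixes if your checkpoints use different ones.
    def strip(name):
        return name.replace("_fsdp_wrapped_module.", "").replace("module.", "")

    sd_b = {strip(k): v for k, v in sd_b.items()}
    for name, a in sd_a.items():
        b = sd_b.get(strip(name))
        if b is None or b.shape != a.shape:
            print(f"{name}: missing in the other checkpoint or shape mismatch")
            continue
        diff = (a.float() - b.float()).abs().max().item()
        ok = torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol)
        print(f"{name}: max_abs_diff={diff:.3e} allclose={ok}")

compare_state_dicts(
    torch.load("mnist_cnn_ddp.pt", map_location="cpu"),
    torch.load("mnist_cnn_fsdp.pt", map_location="cpu"),
)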

I see. Thanks a lot for the explanation!