Hi there! I found that the training losses of FSDP with NO_SHARD and DDP don’t match each other. In my understanding, they implement the same algorithm and should produce the same loss curve. Or am I wrong — are they different implementations, so that we shouldn’t expect them to have the same loss?
I’m running the following code (modified from the MNIST example in pytorch/examples). The losses match on 1 or 2 GPUs but don’t match on more than 2 GPUs.
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import always_wrap_policy
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
class Net(nn.Module):
    """Small convolutional classifier that returns per-class log-probabilities.

    Two 3x3 convolutions are followed by a 2x2 max-pool, dropout, and two
    fully connected layers. ``fc1`` takes 9216 input features, which equals
    64 channels * 12 * 12 and therefore matches 1x28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is so that
        # state_dict keys and seeded initialisation stay unchanged.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return log-softmax scores over the 10 outputs for a batch ``x``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension, flatten everything else for the linear layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
def setup():
    """Initialise the default NCCL process group.

    When ``OMPI_COMM_WORLD_SIZE`` is present in the environment, world size
    and rank are read from the ``OMPI_COMM_WORLD_*`` variables and the group
    is initialised over TCP at ``MASTER_IP`` (default ``localhost``) and
    ``MASTER_PORT`` (default ``8999``). Otherwise initialisation is delegated
    to the ``env://`` init method.
    """
    # Imported locally: the script otherwise only imports ``os`` inside its
    # ``__main__`` guard, so this function would fail when used as a module.
    import os

    if "OMPI_COMM_WORLD_SIZE" in os.environ:
        world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
        rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
        master_ip = os.environ.get("MASTER_IP", "localhost")
        master_port = os.environ.get("MASTER_PORT", "8999")
        init_method = f"tcp://{master_ip}:{master_port}"
        print(f"Initializing distributed '{init_method}'")
        dist.init_process_group(
            backend="nccl",
            init_method=init_method,
            world_size=world_size,
            rank=rank,
        )
    else:
        dist.init_process_group(backend="nccl", init_method="env://")
def tear():
    """Destroy the default process group if one has been initialised.

    The ``is_initialized`` guard makes teardown safe to call when setup
    never ran or when teardown has already happened.
    """
    if dist.is_initialized():
        dist.destroy_process_group()
def train(args, model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Every ``args.log_interval`` batches the current loss is summed across
    ranks with ``all_reduce``, divided by the world size, and printed by
    rank 0. The all-reduce is executed by every rank so the collective stays
    matched. When ``args.dry_run`` is set the loop stops after the first
    logged batch.

    Args:
        args: namespace providing ``log_interval`` and ``dry_run``.
        model: module whose output is fed to ``F.nll_loss``.
        device: device that each batch is moved to.
        train_loader: data loader yielding ``(data, target)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the log line.
    """
    model.train()
    world_size = dist.get_world_size()
    is_main_rank = dist.get_rank() == 0
    # The sampler length is the number of samples this rank iterates over,
    # so the progress counter can actually reach its total when the dataset
    # is sharded across ranks.
    samples_this_rank = len(train_loader.sampler)

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            # Work on a detached copy so the in-place reduction cannot touch
            # the tensor that took part in autograd.
            loss_t = loss.detach().clone()
            dist.all_reduce(loss_t)
            loss_t /= world_size
            if is_main_rank:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        samples_this_rank,
                        100.0 * batch_idx / len(train_loader),
                        loss_t.item(),
                    )
                )
            if args.dry_run:
                break
def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print global metrics on rank 0.

    Each rank accumulates the summed NLL loss, the number of correct
    predictions and the number of samples it saw. The three totals are then
    summed across ranks with ``all_reduce``, so the reported average loss and
    accuracy describe everything evaluated by all ranks rather than one
    rank's shard divided by the full dataset size.

    Args:
        model: module producing log-probabilities for ``F.nll_loss``.
        device: device that each batch is moved to.
        test_loader: data loader yielding ``(data, target)`` batches.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            loss_sum += F.nll_loss(output, target, reduction="sum").item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            seen += target.size(0)

    # float64 keeps the integer counts exact through the reduction.
    totals = torch.tensor([loss_sum, correct, seen], dtype=torch.float64, device=device)
    dist.all_reduce(totals)
    loss_sum, correct, seen = totals.tolist()

    if dist.get_rank() == 0:
        print(
            "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
                loss_sum / seen, int(correct), int(seen), 100.0 * correct / seen
            )
        )
def main():
    """Parse the command line, then train and evaluate ``Net`` under DDP or FSDP.

    Arguments are parsed before the process group is created, so ``--help``
    and argument errors exit without initialising the distributed backend.
    The model parameters are broadcast from rank 0 before wrapping, so both
    wrappers start from the same weights.
    """
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument(
        "--dist",
        "-d",
        type=str,
        choices=("ddp", "fsdp"),
        required=True,
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=14,
        metavar="N",
        help="number of epochs to train (default: 14)",
    )
    parser.add_argument(
        "--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)"
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.7,
        metavar="M",
        help="Learning rate step gamma (default: 0.7)",
    )
    parser.add_argument(
        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
    )
    parser.add_argument(
        "--dry-run", action="store_true", default=False, help="quickly check a single pass"
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--save-model", action="store_true", default=False, help="For Saving the current Model"
    )
    args = parser.parse_args()

    setup()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)

    # Map the global rank onto the devices visible to this process; using the
    # raw global rank fails once there are more ranks than local devices.
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    device = torch.cuda.current_device()
    print(f"Using cuda:{device}")

    train_kwargs = {"batch_size": args.batch_size}
    test_kwargs = {"batch_size": args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {"num_workers": 1, "pin_memory": True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    dataset1 = datasets.MNIST("data", train=True, download=True, transform=transform)
    dataset2 = datasets.MNIST("data", train=False, transform=transform)
    dataset1_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset1, dist.get_world_size(), dist.get_rank(), shuffle=True
    )
    dataset2_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset2, dist.get_world_size(), dist.get_rank(), shuffle=False
    )
    train_loader = torch.utils.data.DataLoader(dataset1, sampler=dataset1_sampler, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, sampler=dataset2_sampler, **test_kwargs)

    model = Net().to(device)
    # Start every rank from rank 0's weights, independent of the wrapper used.
    for param in model.parameters():
        dist.broadcast(param, 0)

    if args.dist == "ddp":
        model = DDP(model)
    else:
        model = FSDP(
            model,
            sharding_strategy=ShardingStrategy.NO_SHARD,
            auto_wrap_policy=always_wrap_policy,
            device_id=device,
        )
    print(model)

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        # Without set_epoch the sampler reuses the same shuffle order in
        # every epoch.
        dataset1_sampler.set_epoch(epoch)
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args.save_model:
        # All ranks build the state dict (it may involve collectives under
        # FSDP), but only rank 0 writes so ranks do not race on one file.
        state = model.state_dict()
        if dist.get_rank() == 0:
            torch.save(state, "mnist_cnn.pt")

    tear()
if __name__ == "__main__":
    import os

    # Request reproducible kernels for the DDP-vs-FSDP comparison. The cuBLAS
    # workspace variable is set before any CUDA work starts; PyTorch's
    # deterministic mode asks for it on CUDA builds.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)
    main()
I use MPI (`mpirun`) to launch multiple processes:
mpirun -np 4 python main.py -d fsdp
mpirun -np 4 python main.py -d ddp
System information:
GPU: V100
Python: 3.8.13
PyTorch: 1.12.1+cu116
torch.version.cuda: 11.6
torch.cuda.nccl.version(): 2.10.3