FSDP with model parallelism

Hello,

I need to implement FSDP in a model-parallel setup: I want my encoder to run on one GPU and the decoder to run on another, while still getting the memory savings, optimization features, and distributed training support that FSDP provides.

I have a machine with 4 GPUs. The following script (without model parallelism) runs with no errors.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from tqdm import tqdm
import os
import torch.distributed as dist
from functools import partial
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _module_wrap_policy


# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 4, 3, padding="same"),
            nn.MaxPool2d(2),
            nn.Conv2d(4, 8, 3, padding="same"),
            nn.MaxPool2d(2),
        )
        self.decoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(7 * 7 * 8, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.encoder(x)
        logits = self.decoder(x)
        return logits


def train(dataloader, model, loss_fn, optimizer, rank):
    model.train()
    with tqdm(
        total=len(dataloader), postfix={"loss": "undefined"}, disable=rank != 0
    ) as pbar:
        for X, y in dataloader:

            # Compute prediction error
            pred = model(X)
            y = y.to(pred.device)
            loss = loss_fn(pred, y)

            # Backpropagation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            pbar.set_postfix({"loss": loss.cpu().item()})
            pbar.update(1)


def test(dataloader, model, loss_fn, rank):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    loss_correct_batches = torch.tensor([0, 0, 0]).to(torch.float32)
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            y = y.to(pred.device)
            loss_correct_batches[0] += loss_fn(pred, y).cpu().item()
            loss_correct_batches[1] += (
                (pred.argmax(1) == y).type(torch.float).sum().cpu().item()
            )
            loss_correct_batches[2] += 1

    loss_correct_batches = loss_correct_batches.to(pred.device)
    dist.all_reduce(loss_correct_batches)

    if rank == 0:
        test_loss, correct, num_batches = loss_correct_batches.cpu().tolist()

        test_loss /= num_batches
        correct /= size
        print(
            f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
        )


def get_policies():
    auto_wrap_policy = partial(
        _module_wrap_policy,
        module_classes={nn.Linear, nn.Conv2d},
    )
    sharding_strategy = [
        ShardingStrategy.FULL_SHARD,
        ShardingStrategy.SHARD_GRAD_OP,
        ShardingStrategy.NO_SHARD,
    ][0]
    prefetch_policy = [
        None,
        BackwardPrefetch.BACKWARD_POST,
        BackwardPrefetch.BACKWARD_PRE,  # 13% speed up, 0.59% peak memory increase
    ][2]
    mp_policy = MixedPrecision(
        param_dtype=torch.float16,  # Param precision
        reduce_dtype=torch.float16,  # Gradient communication precision
        buffer_dtype=torch.float16,  # Buffer precision
    )
    return auto_wrap_policy, sharding_strategy, prefetch_policy, mp_policy


if __name__ == "__main__":
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    device_count_per_proc = 2
    devices = [
        i + local_rank * device_count_per_proc for i in range(device_count_per_proc)
    ]

    dist.init_process_group("nccl")

    # Download training data from open datasets.
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )

    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )

    batch_size = 64

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    model = NeuralNetwork()

    auto_wrap_policy, sharding_strategy, prefetch_policy, mp_policy = get_policies()
    torch.cuda.set_device(devices[0])
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=sharding_strategy,
        backward_prefetch=prefetch_policy,
        device_id=devices[0],
    )

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer, rank)
        test(test_dataloader, model, loss_fn, rank)
    print("Done!")

    dist.destroy_process_group()

I am invoking it with:

torchrun --nnodes 1 --nproc_per_node 2 ./pt-basics-dist.py

But when I try this in a model-parallel setup, wrapping the encoder and the decoder with different devices as follows, I get an “Exception raised from c10_cuda_check_implementation at …/c10/cuda/CUDAException.cpp:44”.

The imports, model definition, train/test functions, policies, and data loading are identical to the script above; only the FSDP wrapping inside the __main__ block changes:
    model = NeuralNetwork()  # .to(device)

    auto_wrap_policy, sharding_strategy, prefetch_policy, mp_policy = get_policies()
    torch.cuda.set_device(devices[0])
    model.encoder = FSDP(
        model.encoder,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=sharding_strategy,
        backward_prefetch=prefetch_policy,
        device_id=devices[0],
    )
    torch.cuda.set_device(devices[1])
    model.decoder = FSDP(
        model.decoder,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=sharding_strategy,
        backward_prefetch=prefetch_policy,
        device_id=devices[1],
    )

The loss function, optimizer, training loop, and process-group teardown are also unchanged.

I think I am using FSDP in an unintended way.

I would also like to know whether there is an easy way to distribute the shards automatically and roughly evenly across multiple GPUs.

Any support would be appreciated.

Thanks!

Well, it sounds like you want a pipeline-parallel setup. In addition to placing the parameters on different devices, you need to account for sending the activations from the encoder on device 0 to the decoder on device 1.
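For reference, here is a minimal sketch (without FSDP) of what that explicit activation transfer looks like. It reuses the NeuralNetwork from the scripts above; the class name and the cuda:0/cuda:1 placement are just illustrative assumptions:

class ModelParallelNetwork(NeuralNetwork):
    def __init__(self, enc_device="cuda:0", dec_device="cuda:1"):
        super().__init__()
        # Assumed placement: encoder and decoder on two different GPUs
        self.enc_device = torch.device(enc_device)
        self.dec_device = torch.device(dec_device)
        self.encoder.to(self.enc_device)
        self.decoder.to(self.dec_device)

    def forward(self, x):
        x = self.encoder(x.to(self.enc_device))
        # Move the activations across GPUs before running the decoder
        x = x.to(self.dec_device)
        return self.decoder(x)

FSDP does not insert that cross-device move for you; each wrapped module computes on its own device and expects its inputs to already be there.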

FSDP and PiPPy (our pipeline-parallel solution) are not yet fully ready to compose. We’re working on updates to PiPPy and a new per-parameter FSDP implementation that will compose together nicely.


Hi Will,

So do you mean that implementing FSDP in a model-parallel setup is currently not possible in PyTorch?