Hello,
I need to implement FSDP in a model parallel setup. I want my encoder to run on a single GPU and the decoder to run on another GPU while harnessing the memory saving options, optimization options, and distributed training options that I get with FSDP.
I have a machine with 4 GPUs. The following script, which does not use model parallelism, runs without errors:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from tqdm import tqdm
import os
import torch.distributed as dist
from functools import partial
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _module_wrap_policy
# Define model
class NeuralNetwork(nn.Module):
    """Small image classifier made of a convolutional ``encoder`` and an MLP ``decoder``.

    The first ``Linear`` layer expects ``7 * 7 * 8`` features, which is what a
    single-channel 28x28 input yields after the two stride-2 poolings and the
    final 8-channel convolution.
    """

    def __init__(self):
        super().__init__()
        # Two conv + max-pool stages; "same" padding keeps the spatial size
        # through each convolution, so only the pools shrink the feature map.
        feature_layers = [
            nn.Conv2d(1, 4, 3, padding="same"),
            nn.MaxPool2d(2),
            nn.Conv2d(4, 8, 3, padding="same"),
            nn.MaxPool2d(2),
        ]
        # Flatten the feature map and map it to 10 output logits.
        head_layers = [
            nn.Flatten(),
            nn.Linear(7 * 7 * 8, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.encoder = nn.Sequential(*feature_layers)
        self.decoder = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the encoder, then the decoder, and return the logits."""
        return self.decoder(self.encoder(x))
def train(dataloader, model, loss_fn, optimizer, rank):
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches; must support ``len()``.
        model: Module to optimise; it is switched to training mode first.
        loss_fn: Callable mapping ``(predictions, targets)`` to a scalar loss.
        optimizer: Optimiser stepped once per batch.
        rank: Process rank; the progress bar is shown only when it is 0.
    """
    model.train()
    progress = tqdm(
        total=len(dataloader), postfix={"loss": "undefined"}, disable=rank != 0
    )
    with progress:
        for inputs, targets in dataloader:
            # Forward pass; targets follow the predictions onto their device.
            outputs = model(inputs)
            batch_loss = loss_fn(outputs, targets.to(outputs.device))

            # Backward pass, parameter update, then clear the gradients.
            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            progress.set_postfix({"loss": batch_loss.cpu().item()})
            progress.update(1)
def test(dataloader, model, loss_fn, rank):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss on rank 0.

    Per-process totals are summed across all ranks when a process group is
    initialised. Accuracy is normalised by the number of samples that were
    actually evaluated (summed over ranks), not by ``len(dataloader.dataset)``.
    This keeps the figure correct both when every rank iterates the full
    dataset and when a distributed sampler gives each rank its own slice.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches.
        model: Module to evaluate; it is switched to eval mode first.
        loss_fn: Callable mapping ``(predictions, targets)`` to a scalar loss.
        rank: Process rank; only rank 0 prints the summary.
    """
    model.eval()
    # [summed batch loss, correct predictions, batches seen, samples seen]
    totals = torch.zeros(4, dtype=torch.float64)
    device = None
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            device = pred.device
            y = y.to(device)
            totals[0] += loss_fn(pred, y).item()
            totals[1] += (pred.argmax(1) == y).sum().item()
            totals[2] += 1
            totals[3] += y.shape[0]

    # The reduction runs on the device the model produced its outputs on;
    # with an empty loader there is no such device and the tensor stays on CPU.
    if device is not None:
        totals = totals.to(device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(totals)

    if rank == 0:
        test_loss, correct, num_batches, num_samples = totals.cpu().tolist()
        # max(..., 1) avoids a ZeroDivisionError when nothing was evaluated.
        test_loss /= max(num_batches, 1)
        correct /= max(num_samples, 1)
        print(
            f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
        )


# The name starts with "test", so mark it as not a unit test to keep test
# runners such as pytest from trying to collect it.
test.__test__ = False
def get_policies():
    """Assemble the FSDP settings used by this script.

    Returns:
        A 4-tuple ``(auto_wrap_policy, sharding_strategy, prefetch_policy,
        mp_policy)`` intended to be passed to the ``FSDP`` constructor.
    """
    # Wrap policy keyed on layer type: Linear and Conv2d modules are selected.
    wrap_policy = partial(
        _module_wrap_policy, module_classes={nn.Linear, nn.Conv2d}
    )

    # Other available choices: SHARD_GRAD_OP, NO_SHARD.
    strategy = ShardingStrategy.FULL_SHARD

    # Other available choices: None, BACKWARD_POST.
    # Author's measurement for BACKWARD_PRE: 13% speed up, 0.59% peak memory increase.
    prefetch = BackwardPrefetch.BACKWARD_PRE

    # Half precision for parameters, gradient communication and buffers alike.
    half = torch.float16
    precision = MixedPrecision(
        param_dtype=half,
        reduce_dtype=half,
        buffer_dtype=half,
    )
    return wrap_policy, strategy, prefetch, precision
if __name__ == "__main__":
    # Imported here so the script block carries its own dependency.
    from torch.utils.data.distributed import DistributedSampler

    # Rank information is read from the environment of each launched process.
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # Each process is assigned a contiguous pair of GPU indices; only the
    # first one is used in this (non model-parallel) variant.
    device_count_per_proc = 2
    devices = [
        i + local_rank * device_count_per_proc for i in range(device_count_per_proc)
    ]
    dist.init_process_group("nccl")
    # Select the device before the first collective so NCCL uses the right GPU.
    torch.cuda.set_device(devices[0])

    # Let local rank 0 download the dataset first; the other ranks wait at the
    # barrier and then find the files already in place, instead of all
    # processes writing to the same "data" directory at once.
    if local_rank != 0:
        dist.barrier()
    # Download training data from open datasets.
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )
    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )
    if local_rank == 0:
        dist.barrier()

    batch_size = 64
    # Create data loaders. The distributed samplers give every rank its own
    # slice of the data, so the work is split across processes and the sums
    # reduced in test() count each sample once. shuffle=False keeps the
    # sequential ordering the plain DataLoader used.
    train_sampler = DistributedSampler(
        training_data, num_replicas=world_size, rank=rank, shuffle=False
    )
    test_sampler = DistributedSampler(
        test_data, num_replicas=world_size, rank=rank, shuffle=False
    )
    train_dataloader = DataLoader(
        training_data, batch_size=batch_size, sampler=train_sampler
    )
    test_dataloader = DataLoader(
        test_data, batch_size=batch_size, sampler=test_sampler
    )

    model = NeuralNetwork()
    auto_wrap_policy, sharding_strategy, prefetch_policy, mp_policy = get_policies()
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=sharding_strategy,
        backward_prefetch=prefetch_policy,
        device_id=devices[0],
    )
    loss_fn = nn.CrossEntropyLoss()
    # The optimizer is created after wrapping so it sees the FSDP-managed parameters.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer, rank)
        test(test_dataloader, model, loss_fn, rank)
    print("Done!")
    dist.destroy_process_group()
I launch it with:
torchrun --nnodes 1 --nproc_per_node 2 ./pt-basics-dist.py
But when I try to do this in a model parallel setup while initializing the encoder and decoder with different devices as follows, I get an “Exception raised from c10_cuda_check_implementation at …/c10/cuda/CUDAException.cpp:44”.
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from tqdm import tqdm
import os
import torch.distributed as dist
from functools import partial
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _module_wrap_policy
# Define model
class NeuralNetwork(nn.Module):
    """Small image classifier made of a convolutional ``encoder`` and an MLP ``decoder``.

    The first ``Linear`` layer expects ``7 * 7 * 8`` features, which is what a
    single-channel 28x28 input yields after the two stride-2 poolings and the
    final 8-channel convolution.
    """

    def __init__(self):
        super().__init__()
        # Two conv + max-pool stages; "same" padding keeps the spatial size
        # through each convolution, so only the pools shrink the feature map.
        feature_layers = [
            nn.Conv2d(1, 4, 3, padding="same"),
            nn.MaxPool2d(2),
            nn.Conv2d(4, 8, 3, padding="same"),
            nn.MaxPool2d(2),
        ]
        # Flatten the feature map and map it to 10 output logits.
        head_layers = [
            nn.Flatten(),
            nn.Linear(7 * 7 * 8, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.encoder = nn.Sequential(*feature_layers)
        self.decoder = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the encoder, then the decoder, and return the logits."""
        return self.decoder(self.encoder(x))
def train(dataloader, model, loss_fn, optimizer, rank):
    """Run one optimisation pass over ``dataloader``.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches; must support ``len()``.
        model: Module to optimise; it is switched to training mode first.
        loss_fn: Callable mapping ``(predictions, targets)`` to a scalar loss.
        optimizer: Optimiser stepped once per batch.
        rank: Process rank; the progress bar is shown only when it is 0.
    """
    model.train()
    progress = tqdm(
        total=len(dataloader), postfix={"loss": "undefined"}, disable=rank != 0
    )
    with progress:
        for inputs, targets in dataloader:
            # Forward pass; targets follow the predictions onto their device.
            outputs = model(inputs)
            batch_loss = loss_fn(outputs, targets.to(outputs.device))

            # Backward pass, parameter update, then clear the gradients.
            batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            progress.set_postfix({"loss": batch_loss.cpu().item()})
            progress.update(1)
def test(dataloader, model, loss_fn, rank):
    """Evaluate ``model`` on ``dataloader`` and print accuracy and mean loss on rank 0.

    Per-process totals are summed across all ranks when a process group is
    initialised. Accuracy is normalised by the number of samples that were
    actually evaluated (summed over ranks), not by ``len(dataloader.dataset)``.
    This keeps the figure correct both when every rank iterates the full
    dataset and when a distributed sampler gives each rank its own slice.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches.
        model: Module to evaluate; it is switched to eval mode first.
        loss_fn: Callable mapping ``(predictions, targets)`` to a scalar loss.
        rank: Process rank; only rank 0 prints the summary.
    """
    model.eval()
    # [summed batch loss, correct predictions, batches seen, samples seen]
    totals = torch.zeros(4, dtype=torch.float64)
    device = None
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            device = pred.device
            y = y.to(device)
            totals[0] += loss_fn(pred, y).item()
            totals[1] += (pred.argmax(1) == y).sum().item()
            totals[2] += 1
            totals[3] += y.shape[0]

    # The reduction runs on the device the model produced its outputs on;
    # with an empty loader there is no such device and the tensor stays on CPU.
    if device is not None:
        totals = totals.to(device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(totals)

    if rank == 0:
        test_loss, correct, num_batches, num_samples = totals.cpu().tolist()
        # max(..., 1) avoids a ZeroDivisionError when nothing was evaluated.
        test_loss /= max(num_batches, 1)
        correct /= max(num_samples, 1)
        print(
            f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
        )


# The name starts with "test", so mark it as not a unit test to keep test
# runners such as pytest from trying to collect it.
test.__test__ = False
def get_policies():
    """Assemble the FSDP settings used by this script.

    Returns:
        A 4-tuple ``(auto_wrap_policy, sharding_strategy, prefetch_policy,
        mp_policy)`` intended to be passed to the ``FSDP`` constructor.
    """
    # Wrap policy keyed on layer type: Linear and Conv2d modules are selected.
    wrap_policy = partial(
        _module_wrap_policy, module_classes={nn.Linear, nn.Conv2d}
    )

    # Other available choices: SHARD_GRAD_OP, NO_SHARD.
    strategy = ShardingStrategy.FULL_SHARD

    # Other available choices: None, BACKWARD_POST.
    # Author's measurement for BACKWARD_PRE: 13% speed up, 0.59% peak memory increase.
    prefetch = BackwardPrefetch.BACKWARD_PRE

    # Half precision for parameters, gradient communication and buffers alike.
    half = torch.float16
    precision = MixedPrecision(
        param_dtype=half,
        reduce_dtype=half,
        buffer_dtype=half,
    )
    return wrap_policy, strategy, prefetch, precision
# Script entry point for the model-parallel attempt: the encoder and the
# decoder are wrapped in FSDP separately, each with a different device_id.
# This is the listing that produces the CUDA exception described in the post,
# so the statements and their order are left exactly as reported.
if __name__ == "__main__":
    # Rank information is read from the environment of each launched process
    # (presumably populated by the torchrun launcher; verify for other launchers).
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    # NOTE(review): world_size is read but not used anywhere below.
    world_size = int(os.environ["WORLD_SIZE"])
    # Each process is assigned a contiguous pair of GPU indices:
    # local rank r gets [2r, 2r + 1].
    device_count_per_proc = 2
    devices = [
        i + local_rank * device_count_per_proc for i in range(device_count_per_proc)
    ]
    dist.init_process_group("nccl")
    # NOTE(review): every process runs the download below against the same
    # "data" directory at the same time.
    # Download training data from open datasets.
    training_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )
    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor(),
    )
    batch_size = 64
    # Create data loaders.
    # NOTE(review): no sampler is passed, so each process iterates the whole
    # dataset rather than its own slice.
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)
    # The bare model is created without an explicit device placement.
    model = NeuralNetwork()  # .to(device)
    auto_wrap_policy, sharding_strategy, prefetch_policy, mp_policy = get_policies()
    # First half: make devices[0] current and wrap the encoder for it.
    torch.cuda.set_device(devices[0])
    model.encoder = FSDP(
        model.encoder,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=sharding_strategy,
        backward_prefetch=prefetch_policy,
        device_id=devices[0],
    )
    # Second half: switch the current device to devices[1] and wrap the
    # decoder for it. From here on the process's current CUDA device is
    # devices[1], while the encoder was set up for devices[0].
    # NOTE(review): FSDP appears to be designed around one device per process;
    # two FSDP instances with different device_id values inside the same
    # process (sharing the default process group) is presumably what triggers
    # the reported c10_cuda_check_implementation error. Confirm against the
    # FSDP documentation.
    torch.cuda.set_device(devices[1])
    model.decoder = FSDP(
        model.decoder,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mp_policy,
        sharding_strategy=sharding_strategy,
        backward_prefetch=prefetch_policy,
        device_id=devices[1],
    )
    loss_fn = nn.CrossEntropyLoss()
    # `model` itself is a plain nn.Module here; only its two children are
    # FSDP-wrapped, and the optimizer receives the parameters of both.
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    epochs = 5
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer, rank)
        test(test_dataloader, model, loss_fn, rank)
    print("Done!")
    dist.destroy_process_group()
I suspect I am using FSDP in a way it was not designed for.
I would also like to know if there is an easy method to distribute the shards automatically and almost evenly over multiple GPUs.
Any support would be appreciated.
Thanks!