I’m practicing FSDP in PyTorch 2.6, running it across 2 servers.
My code works well in standalone mode, but in the multi-node setup it cannot load the checkpoint files.
The “.metadata” file is written on only one node (node rank 0),
so the other node (node rank 1) cannot find the “.metadata” file and raises an error.
How can I resolve this?
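For context, my understanding is that dcp.save is a collective call but only the coordinator rank (global rank 0) writes the “.metadata” file, so every rank needs to see the same checkpoint directory. A minimal sketch of the fix I have in mind, assuming a mount shared between both nodes (the “/mnt/shared” path is hypothetical):

import torch.distributed.checkpoint as dcp

# Hypothetical directory on storage that both nodes can see (e.g. an NFS mount);
# with a node-local relative path like 'fsdp_checkpoint', node 1 never sees
# the .metadata file that the coordinator rank writes on node 0
SHARED_CHECKPOINT_DIR = '/mnt/shared/fsdp_checkpoint'

checkpoint_state = {'checkpoint' : DistributeCheckpoint(fsdp_model, optimizer, scheduler)}   # as in Cell 5 below
dcp.save(checkpoint_state, checkpoint_id = SHARED_CHECKPOINT_DIR)
dcp.load(checkpoint_state, checkpoint_id = SHARED_CHECKPOINT_DIR)

Is a shared filesystem the expected setup here?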
In addition, my world size is 8 (4 GPUs per node),
so there are 8 shard files (“__0_0.distcp” through “__3_0.distcp” on node 0, and “__4_0.distcp” through “__7_0.distcp” on node 1).
If one node fails, so that only “__0_0.distcp” through “__3_0.distcp” are still accessible, how can I resume training?
Can I instead save a full state checkpoint (a single “???.pth” rather than multiple “???.distcp” files) on each node, then load it and scatter it to every process? A sketch of what I have in mind follows.
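My guess, based on the torch.distributed.checkpoint.state_dict API, is something like the sketch below: gather the full, unsharded state dict, torch.save it from rank 0 only, and on resume let broadcast_from_rank0 = True scatter the tensors back to every process (the file name “model_full.pth” is just illustrative). Is this the right direction?

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
    StateDictOptions
)

# Save: with full_state_dict and cpu_offload both True, only rank 0
# materializes the full (unsharded) state dict; other ranks get an empty dict
save_options = StateDictOptions(full_state_dict = True, cpu_offload = True)
full_state = get_model_state_dict(fsdp_model, options = save_options)
if dist.get_rank() == 0:
    torch.save(full_state, 'model_full.pth')   # 'model_full.pth' is an illustrative name

# Load: rank 0 reads the single file, every other rank passes an empty dict,
# and broadcast_from_rank0 = True distributes the tensors to all processes
load_options = StateDictOptions(full_state_dict = True, broadcast_from_rank0 = True)
full_state = torch.load('model_full.pth', weights_only = True) if dist.get_rank() == 0 else {}
set_model_state_dict(fsdp_model, full_state, options = load_options)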
My source code is as follows:
# Cell 1
import random
import os
import functools
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
CPUOffload,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
wrap
)
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import (
get_state_dict,
set_state_dict,
get_model_state_dict,
StateDictOptions
)
import torchvision
from torchvision.transforms import ToTensor
# Cell 2
# -------- Execution time and VRAM usage estimation -------- #
start_event = torch.cuda.Event(enable_timing = True)
end_event = torch.cuda.Event(enable_timing = True)
def initialize_cuda_performance_record():
torch.cuda.init()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
start_event.record()
def get_cuda_performance_record():
end_event.record()
torch.cuda.synchronize()
    execution_time = start_event.elapsed_time(end_event) / 1000
    peak_VRAM_usage = torch.cuda.max_memory_allocated()
    return execution_time, peak_VRAM_usage
# Cell 3
# -------- Dataset and data loader -------- #
def get_dataset(train, download):
dataset = torchvision.datasets.FashionMNIST(root = 'data', train = train, download = download, transform = ToTensor())
return dataset
def get_data_loader(distributed, dataset, mini_batch_size, shuffle):
if distributed:
data_loader = DataLoader(dataset, batch_size = mini_batch_size, pin_memory = True, sampler = DistributedSampler(dataset, shuffle = shuffle))
else:
data_loader = DataLoader(dataset, batch_size = mini_batch_size, pin_memory = True, shuffle = shuffle)
return data_loader
# Cell 4
# -------- ANN architecture ------ #
class NeuralNetwork(nn.Module):
def parameter_initializer(self, layer):
if hasattr(layer, 'weight') and hasattr(layer, 'bias'):
torch.nn.init.xavier_normal_(layer.weight)
torch.nn.init.zeros_(layer.bias)
def __init__(self):
super().__init__()
self.cnn_stack1 = nn.Sequential(
nn.Conv2d(in_channels = 1, out_channels = 4, kernel_size = 5, stride = 2, padding = 2, padding_mode = 'zeros'),
nn.GELU(),
nn.Conv2d(in_channels = 4, out_channels = 8, kernel_size = 5, stride = 2, padding = 1, padding_mode = 'zeros'),
nn.GELU(),
nn.Conv2d(in_channels = 8, out_channels = 16, kernel_size = 5, stride = 2, padding = 1, padding_mode = 'zeros'),
nn.GELU()
)
self.cnn_stack2 = nn.Sequential(
nn.Conv2d(in_channels = 1, out_channels = 4, kernel_size = 5, stride = 2, padding = 2, padding_mode = 'zeros'),
nn.GELU(),
nn.Conv2d(in_channels = 4, out_channels = 8, kernel_size = 5, stride = 2, padding = 2, padding_mode = 'zeros'),
nn.GELU()
)
self.linear_stack1 = nn.Sequential(
nn.Flatten(start_dim = 1),
nn.Linear(16 * 2 * 2, 32)
)
self.linear_stack2 = nn.Sequential(
nn.Flatten(start_dim = 1),
nn.Linear(8 * 7 * 7, 128),
nn.Tanh(),
nn.Linear(128, 32)
)
self.softmax_stack = nn.Sequential(
nn.Linear(32, 10),
nn.Softmax(dim = 1)
)
        self.apply(self.parameter_initializer)   # apply() already recurses over all submodules
def forward(self, x):
x1 = self.cnn_stack1(x)
x1 = self.linear_stack1(x1)
x2 = self.cnn_stack2(x)
x2 = self.linear_stack2(x2)
x = x1 + x2
y = self.softmax_stack(x)
return y
# Cell 5
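# -------- Distributed checkpoint save and load -------- #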
class DistributeCheckpoint(Stateful):
def __init__(self, fsdp_model, optimizer, scheduler):
self.fsdp_model = fsdp_model
self.optimizer = optimizer
self.scheduler = scheduler
self.state_dict_option = StateDictOptions(
full_state_dict = True,
cpu_offload = True,
broadcast_from_rank0 = True
)
def state_dict(self):
model_state, optimizer_state = get_state_dict(
self.fsdp_model,
self.optimizer,
options = self.state_dict_option)
checkpoint_state = {
'model_state' : model_state,
'optimizer_state' : optimizer_state,
'scheduler_state' : self.scheduler.state_dict()
}
return checkpoint_state
def load_state_dict(self, checkpoint_state):
set_state_dict(
self.fsdp_model,
self.optimizer,
model_state_dict = checkpoint_state['model_state'],
optim_state_dict = checkpoint_state['optimizer_state'],
options = self.state_dict_option
)
self.scheduler.load_state_dict(checkpoint_state['scheduler_state'])
def save_checkpoint(fsdp_model, optimizer, scheduler):
checkpoint = {
'checkpoint' : DistributeCheckpoint(fsdp_model, optimizer, scheduler)
}
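    # NOTE: checkpoint_id is a relative path, so each node writes into its own
    # local working directory, and only the coordinator rank (global rank 0)
    # writes the .metadata file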
dcp.save(checkpoint, checkpoint_id = 'fsdp_checkpoint')
def load_checkpoint(fsdp_model, optimizer, scheduler):
checkpoint = {
'checkpoint' : DistributeCheckpoint(fsdp_model, optimizer, scheduler)
}
dcp.load(checkpoint, checkpoint_id = 'fsdp_checkpoint')
# -------- ANN Training -------- #
def train(data_loader, fsdp_model, optimizer, accumulation_number = 1):
local_rank = int(os.environ['LOCAL_RANK'])
distributed_loss = torch.zeros(2).to(local_rank)
fsdp_model.train()
for mini_batch_index, (x, t) in enumerate(data_loader):
x = x.to(local_rank)
y = fsdp_model(x)
t = t.to(y.device)
loss = torch.nn.functional.cross_entropy(y, t)
loss.backward()
        if (mini_batch_index + 1) % accumulation_number == 0:
            # FSDP shards gradients across ranks, so use the FSDP-aware clipping method
            fsdp_model.clip_grad_norm_(1e-1)
            optimizer.step()
            optimizer.zero_grad()
mini_batch_size = x.shape[0]
distributed_loss[0] += loss.item() * mini_batch_size
distributed_loss[1] += mini_batch_size
dist.all_reduce(distributed_loss, op = dist.ReduceOp.SUM)
average_loss = distributed_loss[0] / distributed_loss[1]
return average_loss
def training_loop(dataset, mini_batch_size, max_epoch, checkpoint_interval, accumulation_number = 1):
world_size = int(os.environ['WORLD_SIZE'])
global_rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
if local_rank == 0:
        initialize_cuda_performance_record()
training_data_loader = get_data_loader(distributed = True, dataset = dataset, mini_batch_size = mini_batch_size, shuffle = True)
model = NeuralNetwork().to(local_rank)
torch.cuda.set_device(local_rank)
fsdp_model = FSDP(
model,
auto_wrap_policy = functools.partial(wrap.size_based_auto_wrap_policy, min_num_params = 512),
device_id = torch.cuda.current_device(),
cpu_offload = CPUOffload(offload_params = True),
mixed_precision = MixedPrecision(param_dtype = torch.bfloat16, reduce_dtype = torch.bfloat16, buffer_dtype = torch.bfloat16),
backward_prefetch = BackwardPrefetch.BACKWARD_PRE,
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP,
)
optimizer = torch.optim.Adam(fsdp_model.parameters(), lr = 1e-2, betas = (0.9, 0.999), weight_decay = 1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = 0.5)
current_epoch = 0
if os.path.exists('fsdp_checkpoint'):
load_checkpoint(fsdp_model, optimizer, scheduler)
current_epoch = scheduler.last_epoch
if local_rank == 0:
print(f'Resuming training from checkpoint at epoch {current_epoch + 1}\n' +
'\n',
end = ''
)
dist.barrier()
for t in range(current_epoch, max_epoch):
training_data_loader.sampler.set_epoch(t)
print(f'Worker {global_rank + 1} / {world_size} begins Epoch {t + 1 :> 3d} / {max_epoch}\n', end = '')
training_loss = train(training_data_loader, fsdp_model, optimizer, accumulation_number)
scheduler.step()
if local_rank == 0:
print(f' Training average loss: {training_loss :>8f}\n', end = '')
if t + 1 < max_epoch:
print('\n', end = '')
dist.barrier()
if (t + 1) % checkpoint_interval == 0 and (t + 1) != max_epoch:
save_checkpoint(fsdp_model, optimizer, scheduler)
if local_rank == 0:
print(f'Saved training checkpoint at {t + 1} epochs under "fsdp_checkpoint"\n' +
'\n',
end = ''
)
dist.barrier()
if global_rank == 0:
        execution_time, peak_VRAM_usage = get_cuda_performance_record()
print('-------------------------------\n' +
            f'Training with FSDP for {max_epoch} epochs:\n' +
            f' Execution time: {execution_time :>0.4} sec\n' +
f' Peak VRAM usage: {peak_VRAM_usage / (1024 ** 2) :>,.2f} MB\n' +
'-------------------------------\n',
end = ''
)
state_dict_option = StateDictOptions(
full_state_dict = True,
cpu_offload = True,
broadcast_from_rank0 = True
)
model_state = get_model_state_dict(fsdp_model, options = state_dict_option)
if global_rank == 0:
torch.save(model_state, 'model.pth')
print('Saved PyTorch ANN parameters to model.pth\n' +
'-------------------------------\n',
end = ''
)
# Cell 6
# -------- ANN test and inference -------- #
def test(device, data_loader, model):
total_loss = 0
total_correct = 0
model.eval()
with torch.no_grad():
for x, t in data_loader:
x = x.to(device)
y = model(x)
t = t.to(y.device)
loss = torch.nn.functional.cross_entropy(y, t)
mini_batch_size = x.shape[0]
total_loss += loss.item() * mini_batch_size
total_correct += (y.argmax(dim = 1) == t).type(torch.float).sum().item()
dataset_size = len(data_loader.dataset)
average_loss = total_loss / dataset_size
accuracy = total_correct / dataset_size
return average_loss, accuracy
def inference(device, data, model):
x = data.view(1, 1, 28, 28)
model.eval()
with torch.no_grad():
x = x.to(device)
y = model(x)
return y
# Cell 7
# -------- Main function ------ #
if __name__ == '__main__':
number_of_GPU = torch.cuda.device_count()
world_size = int(os.environ['WORLD_SIZE'])
local_world_size = int(os.environ['LOCAL_WORLD_SIZE'])
global_rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
if local_rank == 0:
print(f'PyTorch version: {torch.__version__}\n' +
'-------------------------------\n' +
f'Number of GPU: {number_of_GPU}\n' +
f'World size: {world_size}\n' +
f'Local world size: {local_world_size}\n' +
'-------------------------------\n',
end = ''
)
if local_world_size > number_of_GPU:
if local_rank == 0:
print(f'Need more GPUs in this node\n' +
f' Number of GPU in this node: {number_of_GPU}\n' +
f' This node needs: {local_world_size}\n' +
'-------------------------------\n',
end = '')
exit()
# ---- Training ---- #
dist.init_process_group(backend = 'nccl')
if local_rank == 0:
training_dataset = get_dataset(train = True, download = True)
dist.barrier()
if local_rank != 0:
training_dataset = get_dataset(train = True, download = False)
training_mini_batch_size = 64
max_epoch = 10
accumulation_number = 4
checkpoint_interval = 5
training_loop(training_dataset, training_mini_batch_size, max_epoch, checkpoint_interval, accumulation_number)
dist.destroy_process_group()
# ---- Test and inference ---- #
if global_rank == 0:
        inference_device = 'cuda'
test_dataset = get_dataset(train = False, download = True)
test_mini_batch_size = 64
test_data_loader = get_data_loader(distributed = False, dataset = test_dataset, mini_batch_size = test_mini_batch_size, shuffle = False)
        model = NeuralNetwork().to(inference_device)
        model.load_state_dict(torch.load('model.pth', weights_only = True))
        test_loss, test_accuracy = test(inference_device, test_data_loader, model)
print('Test performance\n' +
f' Average loss: {test_loss :>8f}\n' +
f' Accuracy: {(100 * test_accuracy) :>0.2f}%\n' +
'-------------------------------\n',
end = ''
)
test_sample_index = random.randint(0, len(test_dataset) - 1)
x, t = test_dataset[test_sample_index]
        y = inference(inference_device, x, model)
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
predicted, actual = classes[y.argmax(dim = 1)], classes[t]
print('Random sample inference\n' +
f' Predicted: "{predicted}", Actual: "{actual}"\n' +
'-------------------------------\n',
end = ''
)