During training, I want to gather all the sharded parameters of an FSDP model and run validation on rank 0 only. The code is as follows:
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = torch.nn.Linear(784, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def get_evaluation_dataloader():
    dataset = torch.utils.data.TensorDataset(
        torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,))
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
    return dataloader

def run_evaluation(model, dataloader, device):
    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            total_correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += target.size(0)
    accuracy = total_correct / total_samples
    print(f'Accuracy: {accuracy * 100:.2f}%')
    return accuracy

def main():
    # Initialize the distributed environment
    dist.init_process_group(backend='nccl', init_method='env://')
    # Set the device for the current process
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    # Define and wrap your model in FSDP
    model = MyModel().to(device)
    fsdp_model = FSDP(model)
    # Assume we have a function that returns our evaluation dataloader
    dataloader = get_evaluation_dataloader()
    # Use summon_full_params to gather the full parameters
    # Synchronize before entering summon_full_params context
    dist.barrier()
    with fsdp_model.summon_full_params(fsdp_model):
        if dist.get_rank() == 0:
            # Run evaluation only on rank 0
            run_evaluation(fsdp_model, dataloader, device)
    # Ensure synchronization after each epoch
    dist.barrier()
    # Cleanup
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
I launch the code with:

torchrun --standalone --nnodes=1 --nproc_per_node=2 main.py
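As an aside: I know torchrun also exports a LOCAL_RANK environment variable to each process, so the device selection in main() could equally be written as in the sketch below; on a single node this gives the same value as dist.get_rank(), so I don't think the launch setup itself is the cause.

import os
import torch

# Equivalent device selection using the LOCAL_RANK env var set by torchrun
# (single-node case, where local rank == global rank).
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)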
The evaluation itself completes. However, it then reports the following error:

Expects tensor to be unsharded with size torch.Size([101770]) but got torch.Size([50885])

Notably, 101770 is exactly the total parameter count of MyModel (784*128 + 128 + 128*10 + 10 = 101770), and 50885 is half of that, matching my two processes, so it looks like something that expects the full, unsharded flat parameter is instead seeing a single rank's shard. It seems to be something wrong with the summon_full_params context manager. Any idea?
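For reference, here is a minimal sketch of the variant I was considering instead, using the rank0_only and writeback flags that summon_full_params accepts (flag names taken from the FSDP docs; I haven't verified that this is the intended usage):

# Hypothetical drop-in replacement for the summon_full_params block in main():
# gather the full parameters onto rank 0 only, and skip writeback since
# evaluation does not modify them.
dist.barrier()
with FSDP.summon_full_params(fsdp_model, rank0_only=True, writeback=False):
    if dist.get_rank() == 0:
        run_evaluation(fsdp_model, dataloader, device)
dist.barrier()

Would that be the right pattern here, or is the size mismatch caused by something else?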