FSDP summon_full_params: unsharded size error

During training, I want to gather all sharded parameters of an FSDP model, and run validation on rank 0 only. The code is as follows:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = torch.nn.Linear(784, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def get_evaluation_dataloader():
    dataset = torch.utils.data.TensorDataset(
        torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,))
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
    return dataloader

def run_evaluation(model, dataloader, device):
    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            total_correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += target.size(0)
    
    accuracy = total_correct / total_samples
    print(f'Accuracy: {accuracy * 100:.2f}%')
    return accuracy

def main():
    # Initialize the distributed environment
    dist.init_process_group(backend='nccl', init_method='env://')
    
    # Set the device for the current process (on a single node, the global rank equals the local rank)
    local_rank = dist.get_rank()
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    
    # Define and wrap your model in FSDP
    model = MyModel().to(device)
    fsdp_model = FSDP(model)
    
    # Assume we have a function that returns our evaluation dataloader
    dataloader = get_evaluation_dataloader()

    # Use summon_full_params to gather the full parameters

    # Synchronize before entering summon_full_params context
    dist.barrier()

    with FSDP.summon_full_params(fsdp_model):
        if dist.get_rank() == 0:
            # Run evaluation only on rank 0
            run_evaluation(fsdp_model, dataloader, device)

    # Ensure synchronization after each epoch
    dist.barrier()

    # Cleanup
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

I launch the code with:
torchrun --standalone --nnodes=1 --nproc_per_node=2 main.py

The evaluation completes. However, it then reports the following error:
Expects tensor to be unsharded with size torch.Size([101770]) but got torch.Size([50885]).
Something seems to be wrong with the summon_full_params context manager. Any ideas?

I am curious: why do you need summon_full_params to run eval? Is it because your data loader is only set up to load eval data on rank 0?

Could you consider using dummy data on nonzero ranks? I think you would get a more efficient forward pass if you ran a normal forward pass on all ranks for eval, since you can then still overlap the all-gather with compute.
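To make the dummy-data idea concrete, here is a minimal sketch. The helper names (inputs_for_rank, eval_step) are hypothetical, and it assumes every rank already knows the batch shape and calls the model the same number of times, so that FSDP's collectives line up across ranks.

```python
import torch


def inputs_for_rank(rank, real_batch, batch_shape, device):
    """Return the real batch on rank 0 and a zero-filled stand-in elsewhere."""
    if rank == 0:
        return real_batch.to(device)
    # Placeholder with the agreed shape, so every rank runs the same forward.
    return torch.zeros(batch_shape, device=device)


@torch.no_grad()
def eval_step(model, rank, real_batch, batch_shape, device):
    """Run one forward pass on every rank; only rank 0's output is kept."""
    model.eval()
    output = model(inputs_for_rank(rank, real_batch, batch_shape, device))
    # Outputs computed from dummy data on other ranks are discarded.
    return output if rank == 0 else None
```

In the script above, each rank would call eval_step(fsdp_model, dist.get_rank(), ...) outside of any summon_full_params context, and only rank 0 would accumulate accuracy.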

The code I pasted is a proof of concept. In reality, my input is generated on rank 0 (by calling Hugging Face's generate function). Of course, I can scatter the input to all ranks, but that adds overhead.

Hi @Jiaji_Huang, have you found a solution? I am working on a similar use case (using Hugging Face's generate) and am running into the same issue.

It’s been a while. I didn’t manage to solve the issue. Instead, I scatter the input from rank 0 to all ranks, then run evaluation.