Handling module aliases in FSDP

Hi,

I’m trying to train a model that contains a module alias with FSDP, and I get the following error when loading the saved checkpoint file.

  ......
  File "***/checkpoints.py", line 67, in load_model_checkpoint
    state_dict = torch.load(checkpoint_path, map_location="cpu")
  File "***/lib/python3.8/site-packages/torch/serialization.py", line 1014, in load
    return _load(opened_zipfile,
  File "***/lib/python3.8/site-packages/torch/serialization.py", line 1422, in _load
    result = unpickler.load()
  File "***/lib/python3.8/site-packages/torch/_utils.py", line 202, in _rebuild_tensor_v2
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
  File "***/lib/python3.8/site-packages/torch/_utils.py", line 181, in _rebuild_tensor
    return t.set_(storage._untyped_storage, storage_offset, size, stride)
RuntimeError: Trying to resize storage that is not resizable

Here’s a minimal example that reproduces the error.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

DIM = 16


class Model(torch.nn.Module):
    def __init__(self, dim=512) -> None:
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim)
        self.linear2 = torch.nn.Linear(dim, dim)
        self.linear3 = torch.nn.Linear(dim, dim)

        # linear_alias refers to the same module object as linear3, so the two
        # names share the same parameters
        self.linear_alias = self.linear3


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "14355"
    torch.cuda.set_device(rank)
    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


def demo_basic(rank, world_size):
    setup(rank, world_size)

    model = Model(dim=DIM).to(rank)
    model = FSDP(model, device_id=rank)

    # save the state dict from rank 0 only
    state_dict = model.state_dict()
    if rank == 0:
        torch.save(state_dict, "model.pt")

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(demo_basic, args=(world_size,), nprocs=world_size, join=True)
    chkpt = torch.load("model.pt", map_location="cpu")  # raises "Trying to resize storage that is not resizable"
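
For reference, the torch.save call itself succeeds; only the final torch.load fails. To see whether entries in the saved state dict end up sharing storage (my guess at what trips up deserialization), I added a small diagnostic inside demo_basic right before the save. It only prints storage pointers and doesn’t change the repro:

# Diagnostic added right before torch.save in demo_basic: print the storage
# pointer of every entry, so aliased parameters that share one storage become visible.
if rank == 0:
    for key, tensor in state_dict.items():
        print(key, tensor.untyped_storage().data_ptr())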

Should we use module aliases at all? If I do have to use an alias, how should I handle it with FSDP?
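
One workaround I’m experimenting with is to break any storage sharing by cloning every tensor before saving. This is only a sketch, and I don’t know whether it’s the right way to deal with aliased modules under FSDP or whether it just papers over the problem:

# Possible workaround (not confirmed to be correct): clone each tensor so the
# saved checkpoint no longer contains views into shared storage.
state_dict = model.state_dict()
if rank == 0:
    cloned = {k: v.detach().clone() for k, v in state_dict.items()}
    torch.save(cloned, "model.pt")

I also wondered whether saving inside FSDP.state_dict_type with FullStateDictConfig(offload_to_cpu=True, rank0_only=True) would have a similar effect, but I haven’t verified that either.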

I didn’t find much information online about using aliased modules with FSDP. Could anyone please offer some suggestions? Thank you!