FSDP + torch.compile fails with "RuntimeError: assigned grad has data of a different size"

Running the following code with torchrun --nnodes 1 --nproc-per-node 2 small.py causes the compiler to fail with “RuntimeError: assigned grad has data of a different size”. Is there anything obvious I am doing wrong here?


import torch
from torch import nn
from torch.nn import functional as F
import os

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class SwiGLU(nn.Module):
    """SwiGLU activation applied over the last dimension.

    The input's last dimension is split into two halves (via ``chunk(2)``):
    the first half is the value, the second half is the gate. The result is
    ``silu(gate) * value``, so the output's last dimension is half the input's.
    """

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=-1)
        gated = F.silu(gate)
        return gated * value

class SwiGLUMLP(nn.Module):
    """Feed-forward block: Linear -> SwiGLU -> Linear, mapping model_dim to model_dim.

    Args:
        model_dim: size of the last dimension of both the input and the output.
        device: optional device forwarded to both ``nn.Linear`` layers.
        dtype: optional dtype forwarded to both ``nn.Linear`` layers.
    """

    def __init__(self, model_dim, device=None, dtype=None):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it the first `self.input = ...` raises AttributeError.
        super().__init__()
        # Hidden width is model_dim * 21/8 rounded to the nearest integer.
        hidden_dim = int(round(model_dim * 21 / 8))
        # Project to twice the hidden width: SwiGLU splits its input in half
        # and uses one half as the value and the other as the gate.
        self.input = nn.Linear(model_dim, hidden_dim * 2, bias=False, device=device, dtype=dtype)
        self.output = nn.Linear(hidden_dim, model_dim, bias=False, device=device, dtype=dtype)
        self.act = SwiGLU()

    def forward(self, x):
        x = self.input(x)
        x = self.act(x)
        return self.output(x)
if __name__ == "__main__":
    # Intended to be launched with torchrun, which sets LOCAL_RANK per worker.
    local_rank = int(os.environ['LOCAL_RANK'])
    device = torch.device('cuda', local_rank)

    # Bind this process to its own GPU before creating the NCCL process group,
    # so collectives are not issued on the default device.
    torch.cuda.set_device(device)
    # init_process_group returns None; the default group is used implicitly.
    torch.distributed.init_process_group('nccl')

    try:
        model = SwiGLUMLP(32).to(device)
        model = FSDP(model, use_orig_params=True, device_id=local_rank)
        compiled_model = torch.compile(model)
        inputs = torch.randn(1, 32, device=device)

        # Eager forward pass through the FSDP-wrapped model (no backward follows).
        # NOTE(review): the reporter observed that running this eager forward
        # before the compiled one is what triggers "assigned grad has data of a
        # different size"; removing it avoids the error. Root cause unconfirmed.
        with torch.autocast('cuda'):
            loss = model(inputs).sum()

        # Forward pass through the torch.compile-wrapped FSDP model.
        with torch.autocast('cuda'):
            loss = compiled_model(inputs).sum()
    finally:
        # Tear down the NCCL process group even if the forward passes fail.
        torch.distributed.destroy_process_group()

Update: the error is caused by running the uncompiled model before running the compiled one. Perhaps that leaves some state that confuses the compiler?