Running the following code with torchrun --nnodes 1 --nproc-per-node 2 small.py
causes torch.compile to fail with "RuntimeError: assigned grad has data of a different size". Is there anything obvious I am doing wrong here?
Code:
import torch
from torch import nn
from torch.nn import functional as F
import os
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
class SwiGLU(nn.Module):
    # Split the projection into value and gate halves, gate with SiLU.
    def forward(self, x):
        x, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * x
class SwiGLUMLP(nn.Module):
    def __init__(self, model_dim, device=None, dtype=None):
        super().__init__()
        hidden_dim = int(round(model_dim * 21 / 8))
        # The input projection is doubled so SwiGLU can chunk it into value and gate.
        self.input = nn.Linear(model_dim, hidden_dim * 2, bias=False, device=device, dtype=dtype)
        self.output = nn.Linear(hidden_dim, model_dim, bias=False, device=device, dtype=dtype)
        self.act = SwiGLU()

    def forward(self, x):
        x = self.input(x)
        x = self.act(x)
        return self.output(x)
if __name__ == "__main__":
    from torch import _dynamo
    _dynamo.config.verbose = True

    local_rank = int(os.environ['LOCAL_RANK'])
    # init_process_group returns None, so there is no handle to keep.
    torch.distributed.init_process_group('nccl')
    torch.cuda.set_device(local_rank)

    model = SwiGLUMLP(32).to(f'cuda:{local_rank}')
    model = FSDP(model, use_orig_params=True, device_id=local_rank)
    compiled_model = torch.compile(model)

    input = torch.randn(1, 32, device=f'cuda:{local_rank}')

    # Eager pass: runs without error.
    with torch.autocast('cuda'):
        loss = model(input).sum()
    loss.backward()

    # Compiled pass: this is where the RuntimeError is raised.
    with torch.autocast('cuda'):
        loss = compiled_model(input).sum()
    loss.backward()

    torch.distributed.destroy_process_group()
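One theory I have: the eager backward leaves .grad tensors on the original parameters, and the compiled backward then tries to assign grads of a different size (sharded vs. unsharded views under use_orig_params). A minimal variant to test that, clearing grads before the compiled pass (untested guess on my part, not a known fix):

model.zero_grad(set_to_none=True)  # drop the .grad tensors left by the eager pass (guess: these are what the compiled backward collides with)
with torch.autocast('cuda'):
    loss = compiled_model(input).sum()
loss.backward()

If the error still reproduces with grads cleared, then presumably it is the mix of eager and compiled execution on the same FSDP instance that trips the check.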