FP8 training with torchao but without torchtitan

Hi, I have a training script that trains a custom LLM with FSDP (I'm on PyTorch 2.2, so I can't use FSDP2 for now).

Now I want to use FP8 to speed up training, but torchtitan seems to only support Llama and FSDP2, so I tried something like the following:

from torchao.float8 import convert_to_float8_training
 ...
model = build_model(config_file)
# swap eligible nn.Linear layers for their float8 training equivalents before FSDP wrapping
convert_to_float8_training(model)
model = FSDP(model, **kwargs)
dataloader = get_dataloader()
optimizer = optim.AdamW(
    model.parameters(),
    lr=train_config.lr,
    weight_decay=train_config.weight_decay,
)
train(model, dataloader, optimizer)
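
For reference, I understand convert_to_float8_training also accepts a module_filter_fn that can skip layers which aren't FP8-friendly (the FP8 matmul kernels want dims that are multiples of 16). A rough sketch of what I mean; "lm_head" is just a placeholder name for my model's output projection:

import torch
from torchao.float8 import convert_to_float8_training

def module_filter_fn(mod, fqn):
    # skip the final projection (placeholder name) and any linear layer
    # whose dims are not multiples of 16
    if fqn == "lm_head":
        return False
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

convert_to_float8_training(model, module_filter_fn=module_filter_fn)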

I tested the FP8 matmul with torch._scaled_mm and that operator runs fine on its own, but the FP8 training hangs and I can't figure out why…
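
The standalone test I ran looked roughly like the sketch below (shapes and scales are just illustrative; torch._scaled_mm is a private op and its signature has changed across releases, so this is the PyTorch 2.2-era form where it returns an (output, amax) tuple):

import torch

# illustrative FP8 matmul test; dims must be multiples of 16
a = torch.randn(64, 128, device="cuda").to(torch.float8_e4m3fn)
# the second operand must be column-major, so build it transposed
b = torch.randn(64, 128, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")
out, amax = torch._scaled_mm(a, b, out_dtype=torch.bfloat16, scale_a=scale_a, scale_b=scale_b)
print(out.shape, out.dtype)  # (64, 64), torch.bfloat16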

Does anyone have any suggestions?