Hi, I have a training script that trains a custom LLM with FSDP (I'm on PyTorch 2.2, so I cannot use FSDP2 for now).
Now I want to use FP8 to speed up training, but torchtitan seems to only support Llama and FSDP2, so I tried the following:
from torchao.float8 import convert_to_float8_training
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch.optim as optim
...
model = build_model(config_file)
convert_to_float8_training(model)  # swap nn.Linear modules to float8 before wrapping with FSDP
model = FSDP(model, **kwargs)
dataloader = get_dataloader()
optimizer = optim.AdamW(
    model.parameters(),
    lr=train_config.lr,
    weight_decay=train_config.weight_decay,
)
train(model, dataloader, optimizer)
I tested FP8 matmul with torch._scaled_mm on its own and that operator runs fine, but the FP8 training run hangs and I cannot figure out why…
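For reference, the standalone check looked roughly like the sketch below (torch._scaled_mm is a private API and its exact signature differs across PyTorch versions; the shapes and the unit scales here are just placeholders):

import torch

a = torch.randn(16, 32, device="cuda").to(torch.float8_e4m3fn)      # first operand: row-major
b = torch.randn(64, 32, device="cuda").to(torch.float8_e4m3fn).t()  # second operand: column-major, as the fp8 cuBLASLt path expects
one = torch.tensor(1.0, device="cuda")                               # float32 scalar scales (unit scales for the smoke test)
res = torch._scaled_mm(a, b, scale_a=one, scale_b=one, out_dtype=torch.bfloat16)
out = res[0] if isinstance(res, tuple) else res                      # on 2.2 the op returns an (output, amax) tuple
print(out.shape, out.dtype)                                          # torch.Size([16, 64]) torch.bfloat16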
Does anyone have any suggestions?