Keeping optimizer states in FP32

According to Rae et al., 2021 (Figure A7), training LLMs purely in BF16 can hurt their performance, and the authors recommend keeping the optimizer states in FP32 (the model weights and gradients can stay in BF16).
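To make the setup concrete, here is a minimal single-process sketch of the split described above, assuming the common "FP32 master weights" pattern rather than anything specific from the paper. The BF16 model runs forward/backward, `torch.optim.AdamW` is built on FP32 copies of the parameters (so it allocates `exp_avg` and `exp_avg_sq` in FP32), and the updated weights are copied back into the BF16 model. The names `master_params` and `train_step` are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# BF16 model: forward/backward run in BF16, so .grad tensors are BF16.
model = nn.Linear(16, 4).to(torch.bfloat16)

# Hypothetical FP32 "master" copies of the parameters. The optimizer only sees
# these, so AdamW creates its exp_avg / exp_avg_sq state in FP32.
master_params = [p.detach().clone().float().requires_grad_(True) for p in model.parameters()]
optimizer = torch.optim.AdamW(master_params, lr=1e-3)


def train_step(x, y):
    # Forward/backward through the BF16 model; loss is computed in FP32.
    out = model(x).float()
    loss = nn.functional.mse_loss(out, y)
    model.zero_grad(set_to_none=True)
    loss.backward()

    # Upcast the BF16 gradients onto the FP32 master parameters.
    for master, p in zip(master_params, model.parameters()):
        master.grad = p.grad.detach().float()

    # AdamW update entirely in FP32 (weights and optimizer state).
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # Round the updated FP32 weights back into the BF16 model.
    with torch.no_grad():
        for master, p in zip(master_params, model.parameters()):
            p.copy_(master)
    return loss.item()


x = torch.randn(8, 16, dtype=torch.bfloat16)
y = torch.randn(8, 4)
loss = train_step(x, y)
```

The cost of this design is an extra FP32 copy of the weights per parameter, in exchange for never accumulating AdamW updates in BF16.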

I was wondering how to implement this behavior in PyTorch, especially when combined with FSDP. For instance, with BF16 and AdamW, can I simply cast exp_avg and exp_avg_sq from optimizer.state_dict() to FP32?
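For the FSDP case, one option that may avoid touching `optimizer.state_dict()` at all is FSDP's own mixed-precision policy. This sketch assumes the FSDP1 API (`torch.distributed.fsdp.FullyShardedDataParallel` with `MixedPrecision`): the model is kept in FP32 before wrapping, and `param_dtype=torch.bfloat16` only governs the dtype used during forward/backward. My reading of the FSDP docs is that the sharded parameters stay in full precision outside forward/backward (for the optimizer step), so an optimizer built on them would hold FP32 state; treat this as an assumption to verify. The name `bf16_policy` is hypothetical, and the wrapping lines are left as comments because they need an initialized process group.

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# Hypothetical BF16 compute policy; the model passed to FSDP is assumed to be FP32.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # dtype of parameters during forward/backward compute
    reduce_dtype=torch.bfloat16,  # dtype used for gradient reduction across ranks
    buffer_dtype=torch.bfloat16,  # dtype for buffers
)

# Inside an initialized process group (e.g. launched with torchrun):
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   model = FSDP(model_fp32, mixed_precision=bf16_policy)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
```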

Do you know any existing implementation of this?