Bfloat16 training question

Yo, guys, I am looking to train a 3D CNN on my RTX 3080 Ti. As far as I know it supports bfloat16, but I am struggling to train a model. I tried casting everything straight to bfloat16, and I tried using bfloat16 with AMP, but it throws an error, something like “illegal memory access”. When I just use fp32, everything is fine. Any tips on bfloat16 training, and is it even viable? I have seen that PyTorch Lightning has some kind of support, but I would prefer plain torch.

Could you post a minimal, executable code snippet to reproduce the issue as well as the output of python -m torch.utils.collect_env, please?
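For reference, a minimal sketch of what such a repro could look like. The original model and data were not posted, so everything here is an assumption: Tiny3DCNN, train_step, the layer sizes, and the input shape are hypothetical stand-ins. It runs a single bfloat16 autocast training step with plain torch. No GradScaler is used, since bfloat16 has the same exponent range as fp32 and loss scaling is usually not needed with it.

```python
import math

import torch
import torch.nn as nn


class Tiny3DCNN(nn.Module):
    """Hypothetical stand-in for the real 3D CNN, which was not posted."""

    def __init__(self, in_channels=1, num_classes=2):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, 8, kernel_size=3, padding=1)
        self.head = nn.Linear(8, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        # Global average pool over depth, height, and width.
        x = x.mean(dim=(2, 3, 4))
        return self.head(x)


def train_step(device):
    """Run one optimizer step under bfloat16 autocast; return (logits dtype, loss)."""
    torch.manual_seed(0)
    model = Tiny3DCNN().to(device)  # parameters stay in fp32
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    # Random data with an assumed shape: (batch, channels, depth, height, width).
    x = torch.randn(2, 1, 8, 16, 16, device=device)
    y = torch.randint(0, 2, (2,), device=device)

    # Only the forward pass and the loss are wrapped in autocast.
    with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
        logits = model(x)
        loss = nn.functional.cross_entropy(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return logits.dtype, loss.item()


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype, loss = train_step(device)
    print(f"device={device}, logits dtype={dtype}, loss finite={math.isfinite(loss)}")
```

If a script this small already crashes on the GPU, posting it together with the collect_env output should be enough to look into it. If it runs fine, growing it step by step towards the real model should show which layer or input shape triggers the error.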

Yes, I will try to do it in a day or two.