I have recently been running into an AMP issue when using the bfloat16 dtype with PyTorch's autocast. I am trying to train SAM2 locally, but I am finding it hard to get it to run in bfloat16.
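For context, here is a minimal sketch of the kind of autocast usage I mean; the model, optimizer, and data below are placeholders rather than the actual SAM2 training loop:

```python
import torch

# Placeholders standing in for the real SAM2 model and training data
model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
x = torch.randn(8, 16, device="cuda")
target = torch.randn(8, 16, device="cuda")

# Run the forward pass under bfloat16 autocast; no GradScaler is used
# because bf16 has the same exponent range as fp32
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)
    loss = torch.nn.functional.mse_loss(out, target)

loss.backward()
optimizer.step()
optimizer.zero_grad()
```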
Searching the PyTorch issues, I found one that may be related to my scenario: a bfloat16 autocast runtime error. Although that issue has been closed, the code shown in it still raises an error in my environment (PyTorch 2.4.0, RTX 3090 GPU, CUDA 12.1).
I have also checked the PR that was meant to fix that issue. The problem should be gone by now, but I still encounter the same error and cannot figure out how to work around it.
So could anybody help me out with this?