Bfloat16 dtype runtime error with AMP autocast enabled

Recently I have been running into an AMP issue when using the bfloat16 dtype with PyTorch's autocast. I am trying to train SAM2 locally, but I am finding it hard to get it working in bfloat16.
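
For context, this is roughly how I enable autocast in the training step (a minimal sketch with a placeholder model and loss, not the actual SAM2 code):

```python
import torch

device = "cuda"
model = torch.nn.Linear(1024, 1024).to(device)   # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, device=device)
target = torch.randn(8, 1024, device=device)

optimizer.zero_grad()
# Enable AMP autocast with bfloat16; no GradScaler is needed for bf16.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)
    loss = torch.nn.functional.mse_loss(out, target)
loss.backward()
optimizer.step()
```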

I searched the PyTorch issues and found one that seems related to my scenario, namely a bfloat16 autocast runtime error. Although that issue has been closed, the code shown in it still raises the same error in my environment (PyTorch 2.4.0, RTX 3090 GPU, CUDA 12.1).

However, I cannot work around this, and I have also checked the PR that aims to fix the issue. The error should have been resolved, but I still run into the same one.

So could anybody help me out? :slight_smile:

Hi everyone, after carefully checking the dates, I found that the issue was opened in August, which is after the release of PyTorch 2.4.0.

So I upgraded my PyTorch to the latest release, 2.5.0, and the issue is now solved.
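
In case it helps anyone else hitting the same error, here is a quick sanity check of the environment after upgrading (the exact version strings depend on the build you install):

```python
import torch

print(torch.__version__)               # e.g. 2.5.0+cu121
print(torch.version.cuda)              # e.g. 12.1
print(torch.cuda.is_bf16_supported())  # True on an RTX 3090 (Ampere)
```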