@autocast dramatically decreases performance

While searching for ways to speed up training of my CNN U-Net-based model, I found that torch.cuda.amp provides some autocast features, like ‘autocast’. If I understood the concept correctly, Torch can automatically select the appropriate tensor dtype at runtime. So I tried decorating my ‘forward’ method with @autocast():

    ...
    @autocast()
    def forward(self, x):
        ...
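For context, here is a minimal self-contained sketch of this pattern; the TinyNet module and its single conv layer below are just placeholders, not my actual U-Net:

    import torch
    from torch import nn
    from torch.cuda.amp import autocast

    class TinyNet(nn.Module):  # placeholder module, not the real U-Net
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

        @autocast()  # ops inside forward run in float16/float32 as chosen by autocast on CUDA
        def forward(self, x):
            return self.conv(x)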

Unfortunately, this led to unexpected behavior at training time: the train loss decreases from 0.7 to 0.4 and then stagnates. All weights in the model become 0 (I guess), because the prediction is a black picture… Without @autocast the loss decreases from 0.7 to 0.01 within several epochs and the model performs well (but training is slow).

I wonder what could cause such behavior?

PS: My model takes a FloatTensor (X) and a LongTensor (y, which in my case is basically a 0-1 mask). Previously I tried casting the input to HalfTensor, but this gave no boost in speed (strange).
PPS: One training step of the ~40 million parameter model takes 5-6 seconds on a single RTX 3070 Ti. I guess this is not a really good result, so I think there is some weakness in my code.

I don’t understand AMP very well, but I think you have to do gradient scaling as well, as described here.
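For reference, the usual AMP recipe combines the autocast context manager in the training loop with GradScaler; a minimal sketch, assuming a model, criterion, optimizer, and data loader already exist:

    import torch
    from torch.cuda.amp import autocast, GradScaler

    # assumes `model`, `criterion`, `optimizer`, and `loader` are already defined
    scaler = GradScaler()

    for X, y in loader:
        X, y = X.cuda(), y.cuda()
        optimizer.zero_grad()
        with autocast():                  # forward pass runs in mixed precision
            pred = model(X)
            loss = criterion(pred, y)
        scaler.scale(loss).backward()     # scale the loss so fp16 gradients don't underflow
        scaler.step(optimizer)            # unscales gradients, skips the step if they contain inf/NaN
        scaler.update()                   # adjust the scale factor for the next iteration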

I will test it, thank you!