Error while using 16-bit floats (.half())

Thanks. I took your advice and used torch.autocast instead of .half(), as follows:

import torch

with torch.autocast('cuda'):
    ...  # (my code)
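For anyone landing here with the same error, here is a minimal sketch of the pattern, assuming a toy nn.Linear model in place of "my code". Unlike .half(), which converts every weight to float16, autocast leaves the model in float32 and runs eligible ops (such as linear layers) in lower precision. The CPU fallback with bfloat16 is an assumption added only so the snippet also runs without a GPU.

```python
import torch
import torch.nn as nn

# Assumption: use CUDA + float16 when available, otherwise CPU + bfloat16
# (CPU autocast targets bfloat16 rather than float16).
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

# Hypothetical toy model; note there is no .half() call, so weights stay float32.
model = nn.Linear(16, 4).to(device)
x = torch.randn(8, 16, device=device)

with torch.autocast(device_type=device, dtype=amp_dtype):
    # Ops on autocast's lower-precision list (linear included) run in amp_dtype.
    y = model(x)
    loss = y.float().sum()

# Backward runs outside the autocast context; gradients match the float32 params.
loss.backward()

print(model.weight.dtype)       # torch.float32
print(y.dtype)                  # amp_dtype
print(model.weight.grad.dtype)  # torch.float32
```

Because the parameters themselves never change dtype, there is no float16/float32 mismatch between the weights and the inputs, which is the usual source of errors after calling .half() on only part of the pipeline.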