Great. One thing I should mention: judging from the stack trace, you are using NVIDIA Apex, which enables mixed-precision training, so the torch.half (16-bit) dtype is being used in your code. The fix I suggested works at the code level, but it might affect behavior (I am not sure), although it should not: .float() converts the tensor to 32-bit floating point, whereas torch.half only needs 16 bits, so the converted tensor uses twice the memory. If that matters, you can convert back to half afterwards to save memory or speed things up.
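As a minimal sketch of that pattern (the tensor `x` and the softmax op here are hypothetical placeholders, not from your code): upcast to float32 for the numerically sensitive operation, then cast back to half.

```python
import torch

# fp16 tensor, as Apex mixed precision would produce
x = torch.randn(4, 4, dtype=torch.half)

# run the sensitive op in fp32, then convert back to 16-bit
y = x.float().softmax(dim=-1).half()
```

The intermediate fp32 tensor exists only for the duration of the op, so the extra memory cost is transient.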