You could use AT_DISPATCH_FLOATING_TYPES_AND_HALF to dispatch the code for the float16 type and use scalar_t in the kernel code (similar to e.g. this code); a sketch is shown below.
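Here is a minimal sketch of that pattern. The kernel and the op name (`my_kernel`, `my_op_cuda`) are placeholders for your own extension code, not the original implementation; the dispatch macro and `scalar_t` usage follow the standard ATen extension API.

```cpp
#include <torch/extension.h>
#include <ATen/Dispatch.h>

template <typename scalar_t>
__global__ void my_kernel(const scalar_t* __restrict__ input,
                          scalar_t* __restrict__ output,
                          int64_t numel) {
  const int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < numel) {
    // scalar_t resolves to float, double, or at::Half depending on the dispatch
    output[idx] = input[idx] * static_cast<scalar_t>(2);
  }
}

torch::Tensor my_op_cuda(torch::Tensor input) {
  auto output = torch::empty_like(input);
  const int64_t numel = input.numel();
  const int threads = 256;
  const int blocks = (numel + threads - 1) / threads;

  // AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the lambda for
  // float, double, and half, binding the concrete type to scalar_t
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "my_op_cuda", ([&] {
    my_kernel<scalar_t><<<blocks, threads>>>(
        input.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        numel);
  }));

  return output;
}
```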
Also note that we now recommend using the native mixed-precision training utility via torch.cuda.amp instead of apex/amp.
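For reference, a minimal torch.cuda.amp training loop looks like this. The tiny model and random data are just a hypothetical setup so the snippet runs on its own; replace them with your own model, data, and the custom op.

```python
import torch
import torch.nn as nn

# Hypothetical toy setup so the loop is runnable; swap in your model/data
model = nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

for step in range(10):
    data = torch.randn(8, 10, device='cuda')
    target = torch.randint(0, 2, (8,), device='cuda')

    optimizer.zero_grad()
    # autocast runs the forward pass in float16 where it is numerically safe
    with torch.cuda.amp.autocast():
        output = model(data)
        loss = criterion(output, target)

    # GradScaler scales the loss to avoid float16 gradient underflow
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```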