Batch norm can receive a half tensor and also returns a half tensor in normal cases,
but after tracing, and under torch.no_grad(), it returns float instead of half (see img).
# Minimal reproduction: compare the output dtype of BatchNorm2d on a
# half-precision input when called eagerly vs. through a traced module.

# BatchNorm2d moved to the GPU, with its parameters cast via .float().
bn_layer = torch.nn.BatchNorm2d(10).cuda().float()
# Half-precision input tensor on the GPU.
x = torch.rand(1, 10, 10, 10).cuda().half()

# Case 1: eager (untraced) call.
o1 = bn_layer(x)
print(o1.dtype)

# Trace the layer with a half-precision example input while autograd is off.
with torch.no_grad():
    trace = torch.jit.trace(bn_layer, torch.rand(1, 10, 10, 10).cuda().half())

# Case 2: traced call under torch.no_grad().
with torch.no_grad():
    o2 = trace(x)
    print(o2.dtype)

# Case 3: traced call with autograd enabled.
# NOTE(review): the original paste lost its indentation; this call is assumed
# to sit outside the no_grad block as the comparison case -- confirm.
o3 = trace(x)
print(o3.dtype)