BatchNorm mixed-precision JIT bug

In eager mode, BatchNorm can receive a half tensor and returns a half tensor,

but after tracing, and under torch.no_grad(), it returns float instead of half (see the repro below):

import torch

# Eager mode: a float32 BatchNorm layer accepts a half input and returns half.
bn_layer = torch.nn.BatchNorm2d(10).cuda().float()
x = torch.rand(1, 10, 10, 10).cuda().half()
o1 = bn_layer(x)
print(o1.dtype)  # torch.float16

# Trace the layer with a half example input.
with torch.no_grad():
    trace = torch.jit.trace(bn_layer, torch.rand(1, 10, 10, 10).cuda().half())

# Bug: under torch.no_grad(), the traced module returns float instead of half.
with torch.no_grad():
    o2 = trace(x)
print(o2.dtype)  # torch.float32 (expected torch.float16)

# Without no_grad, the traced module returns half, matching eager mode.
o3 = trace(x)
print(o3.dtype)  # torch.float16