F.conv2d output dtype problem

Problem: when F.conv2d() is called with a bfloat16 input and bfloat16 weight and bias, the output dtype is float16 instead of bfloat16.

code:
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    """Run the 2-D convolution, printing dtype/shape diagnostics first.

    Args:
        input: Tensor to convolve.
        weight: Convolution kernel forwarded to ``F.conv2d``.
        bias: Optional bias forwarded to ``F.conv2d``; may be ``None``.

    Returns:
        The tensor returned by ``F.conv2d``.

    Raises:
        ValueError: On the ``'zeros'`` padding path, if ``input`` or
            ``weight`` contains NaN or Inf values.
    """
    print(f'padding mode = {self.padding_mode}')
    if self.padding_mode != 'zeros':
        # Non-zero padding modes pad explicitly, then convolve with no
        # implicit padding. This path returns before the diagnostics below.
        padded = F.pad(input, self._reversed_padding_repeated_twice,
                       mode=self.padding_mode)
        return F.conv2d(padded, weight, bias, self.stride,
                        _pair(0), self.dilation, self.groups)

    print('---------------------================================')
    print(f'stride = {self.stride}, padding = {self.padding}, '
          f'dilation = {self.dilation}, groups = {self.groups}')

    # isfinite() is False for both NaN and +/-Inf, so one check covers both.
    if not torch.isfinite(input).all():
        raise ValueError("input contain NaN or Inf values")
    if not torch.isfinite(weight).all():
        raise ValueError("weight contain NaN or Inf values")

    # bias is Optional: describe it without dereferencing None.
    bias_desc = 'None' if bias is None else f'{bias.dtype}_{bias.shape}'
    # Adjacent f-strings instead of backslash continuations, so the source
    # indentation does not leak into the printed message.
    print(f'[$$ debug dtype]: input_dtype_shape = {input.dtype}_{input.shape}, '
          f'weight_dtype_shape = {weight.dtype}_{weight.shape}, '
          f'bias_dtype_shape = {bias_desc}')

    # NOTE(review): if this runs inside torch.autocast, conv2d presumably
    # executes in the autocast dtype rather than the input dtype, which could
    # explain a float16 result from bfloat16 inputs -- confirm against the
    # caller's autocast settings.
    print(f'[$$ debug dtype]: autocast_enabled = {torch.is_autocast_enabled()}')

    return F.conv2d(input, weight, bias, self.stride,
                    self.padding, self.dilation, self.groups)

def forward(self, input: Tensor) -> Tensor:
    """Apply the convolution to ``input`` using this module's weight and bias.

    Delegates to ``self._conv_forward`` and prints the dtype and shape of the
    result before returning it.
    """
    result = self._conv_forward(input, self.weight, self.bias)
    # The reporter observed float16 here.
    print(f'[pytorch debut --- temp_v.dtype] = {result.dtype}_{result.shape}')
    return result

torch version: torch 2.3.0+cu121
device: NVIDIA GeForce RTX 4090

I cannot reproduce the issue using torch==2.3.1+cu121:

>>> import torch
>>> torch.__version__
'2.3.1+cu121'
>>> x = torch.randn(1, 3, 224, 224, device="cuda", dtype=torch.bfloat16)
>>> import torch.nn as nn
>>> conv = nn.Conv2d(3, 3, 3, 1, 1).cuda().bfloat16()
>>> out = conv(x)
>>> print(out.dtype)
torch.bfloat16