# `F.conv2d` dtype problem: bfloat16 inputs produce a float16 output

Problem: when `F.conv2d()` is called with an input of dtype `bfloat16` and a weight and bias that are also `bfloat16`, the output has dtype `float16`.

Code (`nn.Conv2d._conv_forward` and `forward`, with debug prints added):

```python
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    # print(f'pytorch forward 2 = {input}')
    if self.padding_mode != 'zeros':
        return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                        weight, bias, self.stride,
                        _pair(0), self.dilation, self.groups)

    print(f'---------------------================================')
    # print(f'input = {input}, weight = {weight}, bias = {bias}, stride = {self.stride}, padding = {self.padding}, dilation = {self.dilation}, groups = {self.groups}')
    print(f'stride = {self.stride}, padding = {self.padding}, dilation = {self.dilation}, groups = {self.groups}')
    if torch.isnan(input).any() or torch.isinf(input).any():
        raise ValueError("input contain NaN or Inf values")
    if torch.isnan(weight).any() or torch.isinf(weight).any():
        raise ValueError("weight contain NaN or Inf values")
    print(f'[$$ debug dtype]: input_dtype_shape = {input.dtype}_{input.shape}, \
        weight_dtype_shape = {weight.dtype}_{weight.shape}, \
        bias_dtype_shape = {bias.dtype}_{bias.shape}')
    return F.conv2d(input, weight, bias, self.stride,
                    self.padding, self.dilation, self.groups)

def forward(self, input: Tensor) -> Tensor:
    temp_v = self._conv_forward(input, self.weight, self.bias)
    print(f'[pytorch debut --- temp_v.dtype] = {temp_v.dtype}_{temp_v.shape}')  # float16
    return temp_v
```

- PyTorch version: 2.3.0+cu121
- Device: NVIDIA RTX 4090

I cannot reproduce the issue with `torch==2.3.1+cu121`; a bfloat16 `nn.Conv2d` applied to a bfloat16 input returns a bfloat16 output:

```python
>>> import torch
>>> torch.__version__
'2.3.1+cu121'
>>> x = torch.randn(1, 3, 224, 224, device="cuda", dtype=torch.bfloat16)
>>> import torch.nn as nn
>>> conv = nn.Conv2d(3, 3, 3, 1, 1).cuda().bfloat16()
>>> out = conv(x)
>>> print(out.dtype)
torch.bfloat16
```