I have PyTorch code that takes bfloat16 input, but part of the computation needs to be done in float32. Once that computation finishes, the result is cast back to bfloat16 and combined with another bfloat16 value. This code runs fine in PyTorch, but when I export it to ONNX and run it through onnxruntime, I get the error ‘Could not find an implementation for MatMul (13) node with name’. What should I do?
Here is my code:
import torch
import torch.nn as nn

class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        self.test1 = torch.full([5, 5], 1, dtype=torch.bfloat16)
        self.test2 = torch.full([5, 5], 1, dtype=torch.float32)

    def forward(self, hidden_states, kk=0):
        # the input dtype here is torch.bfloat16
        input_dtype = hidden_states.dtype
        # if not (input_dtype == torch.float32 or input_dtype == torch.float16 or input_dtype == torch.bfloat16 or input_dtype == torch.float64):
        # upcast to float32 for the variance computation
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        rsqrt = torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = hidden_states * rsqrt
        # cast back to the original dtype (bfloat16) before applying the weight
        return self.weight * hidden_states.to(input_dtype)
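For context, the norm layer itself contains only element-wise ops, so the MatMul in the error presumably comes from another bfloat16 layer in the surrounding model. A minimal sketch along these lines reproduces the same kind of missing-bfloat16-kernel failure (the wrapper module, shapes, file name, and opset below are illustrative, not my actual setup):

import torch
import torch.nn as nn
import onnxruntime as ort

# illustrative wrapper: the norm followed by a linear layer,
# standing in for whatever layer produces the MatMul node
class Wrapper(nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        self.norm = Qwen2RMSNorm(hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x):
        return self.proj(self.norm(x))

# cast the whole model to bfloat16, as in the real setup
model = Wrapper().to(torch.bfloat16).eval()
dummy = torch.randn(1, 8, 64, dtype=torch.bfloat16)

torch.onnx.export(model, (dummy,), "rmsnorm.onnx",
                  input_names=["x"], output_names=["y"],
                  opset_version=13)

# session creation fails with
# "Could not find an implementation for MatMul (13) node with name ..."
sess = ort.InferenceSession("rmsnorm.onnx")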