Training on CPU with intel_extension_for_pytorch

import torch
import intel_extension_for_pytorch as ipex

model.train()
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.float)

I am running into the error below when calling ipex.optimize on an AMD EPYC 7R32 CPU:

  File "/opt/conda/envs/vpytorch/lib/python3.10/site-packages/intel_extension_for_pytorch/frontend.py", line 563, in optimize
    ) = weight_prepack_with_ipex(
  File "/opt/conda/envs/vpytorch/lib/python3.10/site-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py", line 541, in weight_prepack_with_ipex
    opt_model, opt_optmizer, params_attr = convert_rec(
  File "/opt/conda/envs/vpytorch/lib/python3.10/site-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py", line 537, in convert_rec
    setattr(new_m, name, convert_rec(sub_m, optimizer, params_attr)[0])
  File "/opt/conda/envs/vpytorch/lib/python3.10/site-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py", line 537, in convert_rec
    setattr(new_m, name, convert_rec(sub_m, optimizer, params_attr)[0])
  File "/opt/conda/envs/vpytorch/lib/python3.10/site-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py", line 537, in convert_rec
    setattr(new_m, name, convert_rec(sub_m, optimizer, params_attr)[0])
  [Previous line repeated 2 more times]
  File "/opt/conda/envs/vpytorch/lib/python3.10/site-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py", line 535, in convert_rec
    new_m = convert(m, optimizer, params_attr)
  File "/opt/conda/envs/vpytorch/lib/python3.10/site-packages/intel_extension_for_pytorch/nn/utils/_weight_prepack.py", line 499, in convert
    param_wrapper.prepack(m, is_training)
  File "/opt/conda/envs/vpytorch/lib/python3.10/site-packages/intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py", line 475, in prepack
    self.linear_prepack(module, is_training)
  File "/opt/conda/envs/vpytorch/lib/python3.10/site-packages/intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py", line 578, in linear_prepack
    self.op_ctx = torch.ops.ipex_prepack.linear_prepack(
  File "/opt/conda/envs/vpytorch/lib/python3.10/site-packages/torch/_ops.py", line 692, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: could not create a primitive descriptor for an inner product forward propagation primitive
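
In case it helps anyone hitting the same error, here is a minimal sketch of a guard around the call above, so that training can continue with the unoptimized model when the optimization step raises. It does not fix or explain the underlying failure. The helper name optimize_or_fallback is hypothetical (it is not part of intel_extension_for_pytorch), and the sketch assumes a failed call leaves the original model and optimizer in a usable state.

```python
def optimize_or_fallback(optimize_fn, model, optimizer, **kwargs):
    """Call optimize_fn(model, optimizer=optimizer, **kwargs).

    Hypothetical helper, not an IPEX API. If the call raises RuntimeError
    (as in the traceback above), report it and return the original,
    unoptimized (model, optimizer) pair instead.
    """
    try:
        return optimize_fn(model, optimizer=optimizer, **kwargs)
    except RuntimeError as exc:
        # Assumption: the failed call did not modify model/optimizer in place.
        print(f"optimize step failed, continuing without it: {exc}")
        return model, optimizer
```

With the snippet from the top of the post, this would be used as model, optimizer = optimize_or_fallback(ipex.optimize, model, optimizer, dtype=torch.float).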