Precision issue after converting to CoreML

Hello, I am trying to convert my PyTorch model to CoreML format. However, the outputs of the PyTorch model and the converted CoreML model do not match closely.

Here is a minimal example:

import numpy as np 
import torch
import torch.nn as nn

class Sample_model(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.fc1 = nn.Linear(2, n_classes)
        
    def forward(self, x):
        return self.fc1(x)
    

# initialize a model with 3 classes
model_fixed = Sample_model(3)

# create a fixed input tensor of size (batch_size=2, feature_size=2) as an example
features = torch.tensor(np.array([[1, 2], [3, 4]], dtype=np.float32))
# use the first row as the example input for tracing
example_input_trace = features[:1]
traced_model = torch.jit.trace(model_fixed, example_input_trace)


import coremltools as ct

model_ct = ct.convert(
    traced_model,
    convert_to="mlprogram",
    inputs=[ct.TensorType(shape=example_input_trace.shape, dtype=np.float32)]
)

# use the second row as the input for prediction
test_input = features[1:]

coreml_pred = model_ct.predict({'x': test_input.numpy()})['linear_0']
original_torch_pred = model_fixed(test_input).detach().numpy()


print(original_torch_pred - coreml_pred)
# one run printed array([[-0.00027072, -0.00030965, -0.00077653]], dtype=float32), which is not close to zero

np.allclose(original_torch_pred, coreml_pred)
# returns False, but I expected True

PyTorch version: 1.13.1
coremltools version: 7.0
Device: Apple M1 MacBook
OS: macOS Monterey 12.5.1

Is a discrepancy of this size normal, even for a toy model like this one? Or did I miss something in the code?

I'd really appreciate any help.

I found a solution: set the compute precision explicitly when calling the coremltools.convert function:

model_ct = ct.convert(
    traced_model,
    convert_to="mlprogram",
    compute_precision=ct.precision.FLOAT32,
    inputs=[ct.TensorType(shape=example_input_trace.shape, dtype=np.float32)]
)
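Why this works, as far as I can tell: since coremltools 6, converting with convert_to="mlprogram" uses float16 compute precision by default. Float16 carries only about three significant decimal digits (machine epsilon 2**-10, roughly 9.8e-4), so differences of a few 1e-4 are expected, while np.allclose defaults to rtol=1e-5 and atol=1e-8. The NumPy-only sketch that follows illustrates the effect on a layer shaped like nn.Linear(2, 3). The weights and bias are hypothetical values (the real ones come from PyTorch's random initialization), and casting to float16 in NumPy only approximates Core ML's float16 kernels; it does not reproduce them exactly.

```python
import numpy as np

# Hypothetical parameters standing in for nn.Linear(2, 3); the real values come
# from PyTorch's random initialization and are not shown in the post.
weight = np.array([[0.1, -0.2], [0.3, 0.7], [-0.6, 0.9]], dtype=np.float32)
bias = np.array([0.05, -0.15, 0.25], dtype=np.float32)
x = np.array([[3.0, 4.0]], dtype=np.float32)  # same values as the second row of `features`

# Float64 reference computed from the same float32 parameters.
ref = x.astype(np.float64) @ weight.T.astype(np.float64) + bias.astype(np.float64)

# Float32 arithmetic, i.e. what compute_precision=ct.precision.FLOAT32 requests.
y32 = x @ weight.T + bias

# Rough float16 path: inputs, weights and arithmetic in float16, then cast back.
y16 = (x.astype(np.float16) @ weight.T.astype(np.float16)
       + bias.astype(np.float16)).astype(np.float32)

# Machine epsilon: 2**-10 for float16, 2**-23 for float32.
print(np.finfo(np.float16).eps, np.finfo(np.float32).eps)
print("max |error| float16:", np.abs(ref - y16).max())
print("max |error| float32:", np.abs(ref - y32).max())

# np.allclose defaults to rtol=1e-5, atol=1e-8.
print(np.allclose(ref, y32))  # float32 rounding fits inside the default tolerance
print(np.allclose(ref, y16))  # float16 rounding does not
print(np.allclose(ref, y16, rtol=1e-2, atol=1e-2))  # a float16-sized tolerance does
```

If the float16 model is acceptable for your use case (float16 is usually the faster option on Apple hardware), the alternative to FLOAT32 is to keep the default precision and compare with a tolerance sized for float16, as in the last check.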