Inference with Python and C++ gives different answers

I trained a MobileNetV3 model on the CPU and tested it. Although both Python and C++ give the correct class prediction, their output tensors are different.
Python output tensor:

tensor([[-6.8434,  7.5787,  2.4852,  2.0013,  3.2962, -7.6010, 10.4773, -0.4331,
          1.4424,  0.0946, -1.4620, -1.4271, -3.0626, -0.4510, -6.5796, -3.5845,
         -8.1056]], grad_fn=<ViewBackward>)

C++ output tensor:

Columns 1 to 8 -5.9655   1.8631   2.1338   2.5918   1.9550  -7.2201  10.4710   1.2880
Columns 9 to 16 -1.9018   1.3437  -3.1746   2.3237  -0.1883   2.3990  -5.5839  -4.7663
Columns 17 to 17 -8.7208
[ CPUFloatType{1,17} ]

Why is this happening?

Besides, I also tested the inference speed of Python and C++: PyTorch takes 0.21 s while libtorch takes 0.23 s. What makes libtorch slower than PyTorch, and what can I do to solve this?
My Python training and inference programs and my C++ inference program are as follows:
Python training:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms

# Bottleneck (the inverted-residual block), Hswish (the h-swish activation)
# and flowerDataset are defined elsewhere in the project.
class MobileNetV3_large(nn.Module):
    # (out_channels,kernel_size,exp_channels,stride,se,nl)
    cfg=[
        (16,3,16,1,False,'RE'),
        (24,3,64,2,False,'RE'),
        (24,3,72,1,False,'RE'),
        (40,5,72,2,True,'RE'),
        (40,5,120,1,True,'RE'),
        (40,5,120,1,True,'RE'),
        (80,3,240,2,False,'HS'),
        (80,3,200,1,False,'HS'),
        (80,3,184,1,False,'HS'),
        (80,3,184,1,False,'HS'),
        (112,3,480,1,True,'HS'),
        (112,3,672,1,True,'HS'),
        (160,5,672,2,True,'HS'),
        (160,5,960,1,True,'HS'),
        (160,5,960,1,True,'HS')
    ]
    def __init__(self,num_classes=17):
        super(MobileNetV3_large,self).__init__()
        self.conv1=nn.Conv2d(3,16,3,2,padding=1,bias=False)
        self.bn1=nn.BatchNorm2d(16)

        self.layers = self._make_layers(in_channels=16)
        self.conv2=nn.Conv2d(160,960,1,stride=1,bias=False)
        self.bn2=nn.BatchNorm2d(960)

        self.conv3=nn.Conv2d(960,1280,1,1,padding=0,bias=True)
        self.conv4=nn.Conv2d(1280,num_classes,1,stride=1,padding=0,bias=True)

    def _make_layers(self,in_channels):
        layers=[]
        for out_channels,kernel_size,exp_channels,stride,se,nl in self.cfg:
            layers.append(
                Bottleneck(in_channels,out_channels,kernel_size,exp_channels,stride,se,nl)
            )
            in_channels=out_channels
        return nn.Sequential(*layers)

    def forward(self,x):
        out=Hswish(self.bn1(self.conv1(x)))
        out=self.layers(out)
        out=Hswish(self.bn2(self.conv2(out)))
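        # a 224x224 input has been downsampled to 7x7 here, so kernel 7 acts as global average pooling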
        out=F.avg_pool2d(out,7)
        out=Hswish(self.conv3(out))
        out=self.conv4(out)

        a,b=out.size(0),out.size(1)
        out=out.view(a,b)
        return out


MAX_EPOCH=50
BATCH_SIZE=64
LR=0.0001
log_interval=3
val_interval=1

split_dir=os.path.join(".","data","splitData")
train_dir=os.path.join(split_dir,"train")
valid_dir=os.path.join(split_dir,"valid")


train_transform=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])


valid_transform=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
])


train_data=flowerDataset(data_dir=train_dir,transform=train_transform)
valid_data=flowerDataset(data_dir=valid_dir,transform=valid_transform)

train_loader=DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)
valid_loader=DataLoader(dataset=valid_data,batch_size=BATCH_SIZE)

net=MobileNetV3_large(num_classes=17)
if torch.cuda.is_available():
    net.cuda()

criterion=nn.CrossEntropyLoss()

optimizer=optim.Adam(net.parameters(),lr=LR, betas=(0.9, 0.99))

train_curve=list()
valid_curve=list()
net.train()
accurancy_global=0.0
for epoch in range(MAX_EPOCH):
    loss_mean=0.
    correct=0.
    total=0.
    running_loss = 0.0

    for i,data in enumerate(train_loader):
        img,label=data
        img = Variable(img)
        label = Variable(label)
        if torch.cuda.is_available():
            img=img.cuda()
            label=label.cuda()
 
        out=net(img)
        optimizer.zero_grad()
        loss=criterion(out,label)
        print_loss=loss.data.item()

        loss.backward()
        optimizer.step()
        if (i+1)%log_interval==0:
            print('epoch:{},loss:{:.4f}'.format(epoch+1,loss.data.item()))
        _, predicted = torch.max(out.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum()
    print("============================================")
    accurancy=correct / total
    if accurancy>accurancy_global:
        torch.save(net.state_dict(), './weights/best.pkl')
        print("准确率由:", accurancy_global, "上升至:", accurancy, "已更新并保存权值为weights/best.pkl")
        accurancy_global=accurancy
    print('Recognition accuracy for epoch %d: %d%%' % (epoch + 1, 100*accurancy))
torch.save(net.state_dict(), './weights/last.pkl')
print("训练完毕,权重已保存为:weights/last.pkl")

Python inference:

import time

import cv2
import torch
from torchvision import transforms

class Detector(object):
    def __init__(self,net_kind,num_classes=17):
        super(Detector, self).__init__()
        self.net = torch.jit.load(r"E:\my_files\python_program\python_project\MobileNetV3\jitmodel.pth")
        
        
        self.net.eval()
        if torch.cuda.is_available():
            self.net.cuda()

    def load_weights(self,weight_path):
        self.net.load_state_dict(torch.load(weight_path))

    def detect(self,weight_path,pic_path):
        img = cv2.imread(pic_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        img = cv2.resize(img, (244,244))
        trans = transforms.ToTensor()
        img_tensor = trans(img)
        print(img_tensor.shape)
        img_tensor = torch.unsqueeze(img_tensor, 0)
        
        if torch.cuda.is_available():
            img_tensor=img_tensor.cuda()
        print(img_tensor.shape)
        net_output = self.net(img_tensor)
        print(net_output)
        _, predicted = torch.max(net_output.data, 1)
        result = predicted[0].item()
        print(result)

if __name__=='__main__':
    detector=Detector('large',num_classes=17)
    start=time.time()
    detector.detect('./weights/best.pkl','./6.jpg')
    end=time.time()
    print(end-start)

C++ inference:

#include <torch/script.h>
#include <opencv2/opencv.hpp>

#include <chrono>
#include <iostream>

using namespace std;
using namespace std::chrono;

void Classfier(cv::Mat &image)
{
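  // wrap the HWC uint8 OpenCV buffer without copying, then convert to NCHW float in [0, 1]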
  torch::Tensor img_tensor = torch::from_blob(
      image.data, {1, image.rows, image.cols, 3}, torch::kByte);
  img_tensor = img_tensor.permute({0, 3, 1, 2});
  img_tensor = img_tensor.toType(torch::kFloat);
  img_tensor = img_tensor.div(255);
  torch::jit::script::Module module =
      torch::jit::load("/home/peter-233/project/torchtest/jitmodel.pth");
  torch::Tensor output = module.forward({img_tensor}).toTensor();
  std::cout << output << std::endl;
  auto max_result = output.max(1, true);
  auto max_index = std::get<1>(max_result).item<float>();
  std::cout << max_index << std::endl;
}

int main()
{
  cv::Mat image = cv::imread("/home/peter-233/project/torchtest/6.jpg");
  cv::resize(image, image, cv::Size(224, 224));
  cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
  high_resolution_clock::time_point beginTime = high_resolution_clock::now();
  Classfier(image);
  high_resolution_clock::time_point endTime = high_resolution_clock::now();
  milliseconds timeInterval = std::chrono::duration_cast<milliseconds>(endTime - beginTime);
  cout << "Running Time:" << timeInterval.count() << "ms" << endl;
  return 0;
}


You could try to use a static input (e.g. torch.ones) and compare the outputs between the Python frontend and libtorch. If these values differ by a larger error, it would mean that the model itself is not performing the same operations, and you should look into the forward definition to isolate the difference.
On the other hand, if the outputs are equal for a static input, the issue would most likely be in the data loading and preprocessing.
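
For example, a minimal sketch of that check on the Python side, assuming the same traced jitmodel.pth from the question (on the libtorch side, feed an identical torch::ones({1, 3, 224, 224}) tensor and compare the printed values):

import torch

# load the traced model that both frontends use (path assumed from the question)
model = torch.jit.load("jitmodel.pth")
model.eval()

with torch.no_grad():
    x = torch.ones(1, 3, 224, 224)  # static input, bypassing all image preprocessing
    print(model(x))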

Thanks for your reply. It seems that the issue is in my model.

In that case, try to compare the intermediate activations between both approaches to narrow down where the difference is coming from. If you are using if conditions in the forward method, also check which code path is taken.
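
One way to capture intermediate activations on the Python side is with forward hooks; here is a minimal sketch, assuming the eager MobileNetV3_large model with the trained weights (hook support on loaded TorchScript modules varies across PyTorch versions, so the eager model is used):

import torch

net = MobileNetV3_large(num_classes=17)
net.load_state_dict(torch.load('./weights/best.pkl'))
net.eval()

activations = {}

def save_output(name):
    # store the output tensor of the named submodule after each forward pass
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook

for name, module in net.named_modules():
    if name:  # skip the root module itself
        module.register_forward_hook(save_output(name))

with torch.no_grad():
    net(torch.ones(1, 3, 224, 224))

# print a few values per layer to compare against the libtorch run
for name, act in activations.items():
    print(name, act.flatten()[:4].tolist())

On the C++ side, the corresponding values can be printed from inside the forward pass, or the model can be re-traced to return the intermediate tensors as extra outputs.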