How to implement pytorchvideo from input as images

Hi everyone
I am deverloping pytorchvideo on android. The input model is a 10-image. I run it on Python then I get a result that is ok. But when I run it on Android java I get a result that is not the same with Python. I don’t know how to fix it. plz help me to fix it
This is a python code

import torch
import os
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])  # Normalize theo ImageNet
])

model = torch.jit.load("model_video_cls_pytorch/model.ptl")

image_dir = "imgs"
image_paths = [os.path.join(image_dir, img) for img in os.listdir(image_dir) if img.endswith(".jpg")]

image_paths = sorted(image_paths)[:10]

image_tensors = []
for img_path in image_paths:
    image = Image.open(img_path).convert("RGB")
    image_tensor = transform(image)  # Áp dụng transform
    image_tensors.append(image_tensor)

image_tensors = torch.stack(image_tensors)
image_tensors = image_tensors.permute(1, 0, 2, 3)
final_tensor = image_tensors.unsqueeze(0)

print("Final Tensor Shape:", final_tensor.shape)

output = model(final_tensor)
print("output : ", output)

This is an Android code when I run the function getResult. I will read 10 images from assets and convert them to Tensor. After running the inferred model then the output model is not sampled with python. Although I use sample model and sample images

public final class Constants {

    public final static float[] MEAN_RGB = new float[] {0.45f, 0.45f, 0.45f};
    public final static float[] STD_RGB = new float[] {0.225f, 0.225f, 0.225f};
    public final static int COUNT_OF_FRAMES_PER_INFERENCE = 10;
    public final static int TARGET_VIDEO_SIZE = 160;
    public final static int MODEL_INPUT_SIZE = COUNT_OF_FRAMES_PER_INFERENCE * 3 * TARGET_VIDEO_SIZE * TARGET_VIDEO_SIZE;
    public final static int TOP_COUNT = 5;
}


private Module mModule = null;

@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);

    try {
        mModule = LiteModuleLoader.load(assetFilePath(getApplicationContext(), "model.ptl"));
        getResult();
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
}

public static Bitmap loadImageFromAssets(Context context, String fileName) {
    Bitmap bitmap = null;
    try {
        // Lấy AssetManager
        InputStream inputStream = context.getAssets().open(fileName);
        // Tạo Bitmap từ InputStream
        bitmap = BitmapFactory.decodeStream(inputStream);
        inputStream.close();
    } catch (IOException e) {
        e.printStackTrace();
    }
    return bitmap;
}

private Pair<Integer[], Long> getResult() {

    FloatBuffer inTensorBuffer = Tensor.allocateFloatBuffer(Constants.MODEL_INPUT_SIZE);
    String[] files = {
            "imgs/0.jpg",
            "imgs/1.jpg",
            "imgs/2.jpg",
            "imgs/3.jpg",
            "imgs/4.jpg",
            "imgs/5.jpg",
            "imgs/6.jpg",
            "imgs/7.jpg",
            "imgs/8.jpg",
            "imgs/9.jpg"
    };
    for (int i = 0; i < Constants.COUNT_OF_FRAMES_PER_INFERENCE; i++) {
        Bitmap bitmap = loadImageFromAssets(this, files[i]).copy(Bitmap.Config.ARGB_8888, true);
        TensorImageUtils.bitmapToFloatBuffer(bitmap, 0, 0,
                Constants.TARGET_VIDEO_SIZE, Constants.TARGET_VIDEO_SIZE, Constants.MEAN_RGB, Constants.STD_RGB, inTensorBuffer,
                (3) * i * Constants.TARGET_VIDEO_SIZE * Constants.TARGET_VIDEO_SIZE);
    }
    Tensor inputTensor = Tensor.fromBlob(inTensorBuffer, new long[]{1, 3, Constants.COUNT_OF_FRAMES_PER_INFERENCE, 160, 160});

    final long startTime = SystemClock.elapsedRealtime();
    Tensor outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor();
    final long inferenceTime = SystemClock.elapsedRealtime() - startTime;

    final float[] scores = outputTensor.getDataAsFloatArray();

    for (int i = 0; i < scores.length; i++) {
        Log.d("PytorchVideo", "score " + scores[i]);
    }

    Integer scoresIdx[] = new Integer[scores.length];
    for (int i = 0; i < scores.length; i++)
        scoresIdx[i] = i;

    Arrays.sort(scoresIdx, new Comparator<Integer>() {
        @Override public int compare(final Integer o1, final Integer o2) {
            return Float.compare(scores[o2], scores[o1]);
        }
    });

    return new Pair<>(scoresIdx, inferenceTime);
}