Deploying a PyTorch 1D CNN model in an Android app using Java

Hi, I’m trying to deploy a 1D-CNN-based multi-label signal classification model.

However, the output in Python is drastically different from the output in Java, even though both use the same model, which was exported with torch.jit.trace.

Any advice would be appreciated.
I have been trying to solve this problem for a week.

[Python]
  1. model : 1D-CNN-based DenseNet

  2. training input size : (batch, channel, length) = (25, 1, 1000)

  3. output classes : 4 classes / multi-label classification

  4. versions of python and pytorch

  • python : 3.8.8
  • pytorch : 1.9.0+cu111

  5. input/output data type : torch.float32

[Java]

  1. pytorch_android version : 1.9.1
  2. input/output data type : float32(float)
  3. input tensor data :
  • real_output : float[1000]
  • shape : long[]{1,1,1000}
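Since the shape and data type look consistent on both sides, one thing that may be worth ruling out is a difference in the input values themselves (for example, scaling or normalization applied in Python but not on the device). The following is only a minimal sketch of such a check, assuming the same statistics are also printed for the Python input tensor. `SignalStats` and `summarize` are hypothetical names and are not part of pytorch_android.

```java
// Hypothetical debugging helper (not part of pytorch_android).
// Summarizes a raw input signal so the numbers can be compared by eye
// with the same statistics computed on the Python side.
final class SignalStats {
    private SignalStats() {}

    /** Returns {min, max, mean} of the signal. */
    static float[] summarize(float[] signal) {
        if (signal.length == 0) {
            throw new IllegalArgumentException("signal is empty");
        }
        float min = signal[0];
        float max = signal[0];
        double sum = 0.0; // accumulate in double to limit rounding error
        for (float v : signal) {
            min = Math.min(min, v);
            max = Math.max(max, v);
            sum += v;
        }
        return new float[] {min, max, (float) (sum / signal.length)};
    }
}
```

On the device this could be called as `SignalStats.summarize(real_output)` just before the tensor is created, with the three values logged next to the model output.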

[Java Code]
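// real_output holds the 1000 input samples; shape is {1, 1, 1000} (batch, channel, length)
final Tensor inputTensor = Tensor.fromBlob(real_output, shape, MemoryFormat.CONTIGUOUS);
// Run the traced model and read the 4 class outputs back as a float array
final Tensor ai_output = module.forward(IValue.from(inputTensor)).toTensor();
final float[] ai_result = ai_output.getDataAsFloatArray();
for (int i = 0; i < 4; i++) {
    Log.d(TAG, "ai_result_" + i + " : " + ai_result[i]);
}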
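
To put a number on "drastically different", the four logged values can be compared against the Python output for the exact same 1000-sample input. The sketch below assumes the Python result has been copied by hand into a float array; `OutputCheck`, `maxAbsDiff`, `matches`, and the tolerance are hypothetical and are not part of pytorch_android.

```java
// Hypothetical helper (not part of pytorch_android): compares the Android
// output with a reference vector copied from the Python run.
final class OutputCheck {
    private OutputCheck() {}

    /** Largest absolute elementwise difference between two equal-length arrays. */
    static float maxAbsDiff(float[] actual, float[] expected) {
        if (actual.length != expected.length) {
            throw new IllegalArgumentException(
                "length mismatch: " + actual.length + " vs " + expected.length);
        }
        float max = 0f;
        for (int i = 0; i < actual.length; i++) {
            float diff = Math.abs(actual[i] - expected[i]);
            if (Float.isNaN(diff)) {
                return Float.NaN; // a NaN on either side is reported as NaN
            }
            max = Math.max(max, diff);
        }
        return max;
    }

    /** True when every element agrees within the given tolerance. */
    static boolean matches(float[] actual, float[] expected, float tolerance) {
        float diff = maxAbsDiff(actual, expected);
        return !Float.isNaN(diff) && diff <= tolerance;
    }
}
```

For example, `OutputCheck.maxAbsDiff(ai_result, pythonReference)` could be logged after the loop, where `pythonReference` is a hypothetical `float[4]` holding the values printed by the Python script. A small difference would point to ordinary floating-point noise, while a large one would point to a difference in the input, the preprocessing, or the exported model.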