com.facebook.jni.CppException: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true

I am trying to run a seq2seq NMT model, following the PyTorch demo application project.
I converted my own model to the lite format using this script:

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

# quant_encoder, quant_decoder and quant_generator are defined elsewhere (not shown).
# Example inputs used only for tracing.
encoder_input = torch.tensor([[[429]]])
encoder_length = torch.tensor([1])

decoder_input = torch.tensor([[[123]]])
encoder_output = torch.zeros([1, 1, 512])
decoder_step = torch.tensor([0])

decoder_output = torch.zeros([1, 1, 256])

traced_encoder = torch.jit.trace(quant_encoder, (encoder_input, encoder_length))
args = [decoder_input, encoder_output, decoder_step, encoder_length]
kwargs = {"memory_lengths": encoder_length}  # defined but not passed to trace below
quant_decoder.init_state(encoder_input, None, None)
traced_decoder = torch.jit.trace(quant_decoder, args, strict=False)
traced_generator = torch.jit.trace(quant_generator, (decoder_output,))

# Optimize each traced module and save it for the lite interpreter.
traced_encoder_optimized = optimize_for_mobile(traced_encoder)
traced_encoder_optimized._save_for_lite_interpreter("optimized_encoder_150k.ptl")
traced_decoder_optimized = optimize_for_mobile(traced_decoder)
traced_decoder_optimized._save_for_lite_interpreter("optimized_decoder_150k.ptl")
traced_generator_optimized = optimize_for_mobile(traced_generator)
traced_generator_optimized._save_for_lite_interpreter("optimized_generator_150k.ptl")

It appears to complete successfully, or at least it reports no error. Then I run the model from my Android app like this:

    int input_length = inputs.length;
    final long[] inputShape = new long[]{1, input_length, 1};
    final long[] inputlengthShape = new long[]{input_length};

    final long[] outputsShape = new long[]{MAX_LENGTH, HIDDEN_SIZE};
    final FloatBuffer outputsTensorBuffer =
            Tensor.allocateFloatBuffer(MAX_LENGTH  * HIDDEN_SIZE);
    LongBuffer inputlengthTensorBuffer = Tensor.allocateLongBuffer(1);
    Tensor inputlengthTensor = Tensor.fromBlob(inputlengthTensorBuffer, inputlengthShape);

    for (int i=0; i<inputs.length; i++) {
        LongBuffer inputTensorBuffer = Tensor.allocateLongBuffer(input_length);
        inputTensorBuffer.put(inputs[i]);

        Tensor inputTensor = Tensor.fromBlob(inputTensorBuffer, inputShape);
        Log.i("INFO input tensor", Arrays.toString(inputTensor.getDataAsLongArray()));
        Log.i("INFO input tensor shape", Arrays.toString(inputTensor.shape()));
        Log.i("INFO input length tensor shape", Arrays.toString(inputlengthTensor.shape()));
        IValue temp = IValue.from(inputTensor);
        Log.i("INFO input tensor test", Arrays.toString(temp.toTensor().getDataAsLongArray()));
        final IValue[] outputTuple = mModuleEncoder.forward(IValue.from(inputTensor), IValue.from(inputlengthTensor)).toTuple();

The input is a single integer representing one character, but the call always fails with this error:

com.facebook.jni.CppException: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)

It’s a little hard to debug this, as I can’t get at the intermediate results. I’ve been struggling with this issue for two days…
Any response is appreciated!

Do you see any log output that you could use to narrow down the lines of code raising the error?
If so, log the shapes of the tensors involved and post the operation here, so that we can look at why the shape mismatch is raised.

I/INFO input tensor: [3118]
I/INFO input tensor shape: [1, 1, 1]
I/INFO input length tensor shape: [1]
I/INFO input tensor test: [3118]
E/AndroidRuntime: FATAL EXCEPTION: Thread-2
Process: com.example.seq2seqnmt, PID: 23035
com.facebook.jni.CppException: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)

  debug_handle:-1
  
Exception raised from bmm_out_or_baddbmm_ at ../aten/src/ATen/native/LinearAlgebra.cpp:1206 (most recent call first):
(no backtrace available)
    at org.pytorch.LiteNativePeer.forward(Native Method)
    at org.pytorch.Module.forward(Module.java:52)
    at org.pytorch.demo.seq2seqnmt.MainActivity.translate(MainActivity.java:220)
    at org.pytorch.demo.seq2seqnmt.MainActivity.run(MainActivity.java:108)
    at java.lang.Thread.run(Thread.java:764)

D/EGL_emulation: eglMakeCurrent: 0xa88d0fe0: ver 3 0 (tinfo 0x982291e0)
E/SpellCheckerSession: ignoring processOrEnqueueTask due to unexpected mState=TASK_CLOSE scp.mWhat=TASK_CLOSE

Thank you for your reply! This is all the log output I get.

Is there any help? :smiley:

The error is raised in a matrix multiplication inside the model (bmm_out_or_baddbmm_ in the stack trace), so I guess in a linear layer.
I don’t know how the model is defined, but based on this thread you might be running into a genuine shape mismatch. Check the model implementation and make sure that the tensors have the expected shapes internally. I’m not deeply familiar with the mobile backend, so I can’t rule out a real bug in the optimization path or the backend itself.
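
As a rough illustration of the check behind this message, here is a pure-Python sketch of the shape rule for a batched matrix multiplication such as torch.bmm: the first operand has shape (bs, n, contraction_size) and the second must have shape (bs, contraction_size, m). The helper name `bmm_shape_problems` and the example shapes are made up for illustration; only `bs` and `contraction_size` echo the error text, and the real shapes inside the model are unknown.

```python
def bmm_shape_problems(batch1_shape, batch2_shape):
    """Return a list of reasons why a batched matmul of these shapes would be rejected."""
    if len(batch1_shape) != 3 or len(batch2_shape) != 3:
        return ["both operands must be 3-D"]
    problems = []
    # batch1 is (bs, n, contraction_size); batch2 must be (bs, contraction_size, m).
    bs, _, contraction_size = batch1_shape
    if batch2_shape[0] != bs:
        problems.append(f"batch size differs: {bs} vs {batch2_shape[0]}")
    if batch2_shape[1] != contraction_size:
        problems.append(
            f"contraction size differs: {contraction_size} vs {batch2_shape[1]}"
        )
    return problems


# Hypothetical shapes: the second call mimics a batch size of 2 baked in at trace time.
ok = bmm_shape_problems((1, 1, 256), (1, 256, 1))
bad = bmm_shape_problems((1, 1, 256), (2, 256, 1))
print(ok)
print(bad)
```

Logging the two operand shapes just before the failing multiplication in the Python model and running them through a check like this may show which of the two conditions is violated.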

Thanks for your reply! I tried a standard transformer model instead and ran into a strange error:

I/INFO token: 3118
I/INFO input tensor: [3118]
I/INFO input tensor shape: [1, 1, 1]
I/INFO input length tensor: [1]
I/INFO input length tensor shape: [1]
I/INFO input tensor test: [3118]
E/AndroidRuntime: FATAL EXCEPTION: Thread-2
Process: com.example.seq2seqnmt, PID: 12138
com.facebook.jni.CppException: shape ‘[2, -1, 8, 32]’ is invalid for input of size 256
debug_handle:-1

Exception raised from run at ../torch/csrc/jit/mobile/module.cpp:206 (most recent call first):
(no backtrace available)
    at org.pytorch.LiteNativePeer.forward(Native Method)
    at org.pytorch.Module.forward(Module.java:52)
    at org.pytorch.demo.seq2seqnmt.MainActivity.translate(MainActivity.java:237)
    at org.pytorch.demo.seq2seqnmt.MainActivity.run(MainActivity.java:108)
    at java.lang.Thread.run(Thread.java:764)

D/EGL_emulation: eglMakeCurrent: 0xa8a850c0: ver 3 0 (tinfo 0xa8a83210)
E/SpellCheckerSession: ignoring processOrEnqueueTask due to unexpected mState=TASK_CLOSE scp.mWhat=TASK_CLOSE

When I converted the PyTorch model to the lite model, my input batch size was 2. The input batch size is now 1, as the log shows, but the batch size in the first dimension still seems to be 2, which is a little weird.
I know that a TF-Lite model keeps the sequence length and batch size fixed when the Python model is converted to a lite or C++ model. Does PyTorch do the same thing, i.e. fix the sequence length and batch size? If so, should I resize the model inputs at inference time? As far as I know, TF-Lite handles this.

If I’m not mistaken, torch.jit.trace would use static shapes for dim1+ but would still allow a flexible batch size (torch.jit.script would allow more flexible inputs as well as data-dependent control flow).
The current error is raised in a view or reshape operation that tries to reshape the tensor to [2, -1, 8, 32]. That target shape needs a multiple of 2 × 8 × 32 = 512 elements, while the input contains only 256.
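
To make the arithmetic behind the reshape error concrete, here is a small pure-Python sketch of how a -1 dimension is resolved from the element count. The function name `infer_view_shape` is hypothetical and this is not PyTorch's implementation; it only mirrors the rule that the known dimensions must divide the number of elements.

```python
from math import prod


def infer_view_shape(numel, shape):
    """Resolve a single -1 in `shape` for a tensor holding `numel` elements.

    Raises ValueError when the requested shape cannot hold exactly `numel` elements.
    """
    if shape.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    known = prod(d for d in shape if d != -1)  # product of the fixed dimensions
    if -1 in shape:
        if known == 0 or numel % known != 0:
            raise ValueError(f"shape {shape} is invalid for input of size {numel}")
        return [numel // known if d == -1 else d for d in shape]
    if known != numel:
        raise ValueError(f"shape {shape} is invalid for input of size {numel}")
    return list(shape)


# One token with 256 features fits when the leading dimension is 1 ...
fits = infer_view_shape(256, [1, -1, 8, 32])
print(fits)

# ... but not when the leading dimension is fixed at 2, as in the log.
try:
    infer_view_shape(256, [2, -1, 8, 32])
    failed = False
except ValueError as err:
    failed = True
    print(err)
```

The leading 2 in the target shape matches the batch size used when the model was converted, so one plausible reading is that this value was recorded as a constant during tracing. That is an assumption, not something the log confirms.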

Yes, my hidden size is 256, and the input is only one token so the input of reshape should be a (1, 1, 256) tensor. The problem now is that the batch size has somehow become 2.

I tried fixing the sequence length to 100 and padding any input shorter than 100 words. The encoder part now looks good, but the decoder part fails with a new error:

E/AndroidRuntime: FATAL EXCEPTION: Thread-2
    Process: com.example.seq2seqnmt, PID: 17870
    com.facebook.jni.CppException: isObject()INTERNAL ASSERT FAILED at "../aten/src/ATen/core/ivalue_inl.h":115, please report a bug to PyTorch. Expected Object but got Tensor
      
      debug_handle:-1
      
    Exception raised from toObject at ../aten/src/ATen/core/ivalue_inl.h:115 (most recent call first):
    (no backtrace available)
        at org.pytorch.LiteNativePeer.forward(Native Method)
        at org.pytorch.Module.forward(Module.java:52)
        at org.pytorch.demo.seq2seqnmt.MainActivity.translate(MainActivity.java:283)
        at org.pytorch.demo.seq2seqnmt.MainActivity.run(MainActivity.java:108)
        at java.lang.Thread.run(Thread.java:764)
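
For reference, the fixed-length padding described before this log can be sketched in plain Python as follows. The padding index `PAD_ID = 0` and the choice to truncate longer inputs are assumptions made for illustration; the real padding index depends on the model's vocabulary.

```python
MAX_LEN = 100  # fixed sequence length mentioned in the post
PAD_ID = 0     # assumption: replace with the vocabulary's actual padding index


def pad_tokens(token_ids, max_len=MAX_LEN, pad_id=PAD_ID):
    """Return (padded_ids, true_length); inputs longer than max_len are truncated."""
    kept = list(token_ids[:max_len])
    true_length = len(kept)
    padded = kept + [pad_id] * (max_len - true_length)
    return padded, true_length


padded, length = pad_tokens([3118])
print(len(padded), length)
```

The true length is returned alongside the padded ids because the encoder in this thread also takes a length tensor as input.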

I can’t see what exactly is failing, but as the error message suggests, could you create an issue on GitHub please?

Yes, I have created an issue on GitHub. I hope there will be some advice, and thank you so much for your help! :smiley: