Converting the Donut model to ONNX causes different outputs compared to PyTorch

Hi everyone,

I'm trying to convert the Donut model to ONNX (this one from Hugging Face: naver-clova-ix/donut-base).

I figured out that I need to split the model in two: the first part is the encoder and the second is the decoder.
The encoder export seems to work well, but the ONNX decoder gives totally different results from the PyTorch one, which leads to wrong predictions.

When I use the torch.onnx.verification.find_mismatch() function, I get this:

===================== Mismatch info for graph partition : ======================
================================ Mismatch error ================================
Tensor-likes are not close!

Mismatched elements: 4 / 115050 (0.0%)
Greatest absolute difference: 4.604458808898926e-06 at index (0, 1, 11195) (up to 1e-07 allowed)
Greatest relative difference: 0.0069338479079306126 at index (0, 1, 11195) (up to 0.001 allowed)
==================================== Tree: =====================================
497 X __248 X __124 ✓
id: | id: 0 | id: 00
| |
| |__124 X __62 X __31 ✓
| id: 01 | id: 010 | id: 0100
| | |
| | |__31 X __15 X __7 ✓
| | id: 0101 | id: 01010 | id: 010100
| | | |
| | | |__8 ✓
| | | id: 010101
| | |
| | |__16 ✓
| | id: 01011
| |
| |__62 ✓
| id: 011
|
|__249 X __124 X __62 X __31 X __15 ✓
id: 1 | id: 10 | id: 100 | id: 1000 | id: 10000
| | | |
| | | |__16 X __8 X __4 ✓
| | | id: 10001 | id: 100010 | id: 1000100
| | | | |
| | | | |__4 ✓
| | | | id: 1000101
| | | |
| | | |__8 X __4 ✓
| | | id: 100011 | id: 1000110
| | | |
| | | |__4 ✓
| | | id: 1000111
| | |
| | |__31 ✓
| | id: 1001
| |
| |__62 ✓
| id: 101
|
|__125 X __62 X __31 X __15 X __7 X __3 ✓
id: 11 | id: 110 | id: 1100 | id: 11000 | id: 110000 | id: 1100000
| | | | |
| | | | |__4 ✓
| | | | id: 1100001
| | | |
| | | |__8 X __4 ✓
| | | id: 110001 | id: 1100010
| | | |
| | | |__4 ✓
| | | id: 1100011
| | |
| | |__16 ✓
| | id: 11001
| |
| |__31 X __15 X __7 X __3 ✓
| id: 1101 | id: 11010 | id: 110100 | id: 1101000
| | | |
| | | |__4 X __2 ✓
| | | id: 1101001 | id: 11010010
| | | |
| | | |__2 X __1 ✓
| | | id: 11010011 | id: 110100110
| | | |
| | | |__1 X (aten::linear)
| | | id: 110100111
| | |
| | |__8 ✓
| | id: 110101
| |
| |__16 ✓
| id: 11011
|
|__63 X __31 ✓
id: 111 | id: 1110
|
|__32 X __16 ✓
id: 1111 | id: 11110
|
|__16 ✓
id: 11111
=========================== Mismatch leaf subgraphs: ===========================
['01010', '100010', '100011', '110000', '110001', '110100111', '1111']
============================= Mismatch node kinds: =============================
{'aten::linear': 1}


For the leaf with id "110100111", the find_partition function gives me:

GraphInfo(graph=graph(%aten::reshape_1284 : Float(1, 16, 2, 64, strides=[2048, 128, 64, 1], requires_grad=0, device=cpu),
%aten::size_1260 : Float(1, 2, 1024, strides=[2048, 1024, 1], requires_grad=0, device=cpu),
%model.decoder.layers.3.self_attn.v_proj.weight : Float(1024, 1024, strides=[1024, 1], requires_grad=1, device=cpu),
%model.decoder.layers.3.self_attn.v_proj.bias : Float(1024, strides=[1], requires_grad=1, device=cpu)):
%aten::view_1285 : Float(1, 2, 1024, strides=[2048, 1024, 1], requires_grad=0, device=cpu) = aten::linear(%aten::size_1260, %model.decoder.layers.3.self_attn.v_proj.weight, %model.decoder.layers.3.self_attn.v_proj.bias) # C:\Users\pierre.dumas\AppData\Local\anaconda3\envs\Donut\Lib\site-packages\torch\nn\modules\linear.py:116:0
return (%aten::view_1285)
, input_args=(tensor([[[[ 0.0417, 0.3818, -0.5706, ..., 0.4427, -0.2359, 0.6942],
[ 0.0660, -0.1410, -0.5272, ..., -0.2237, -0.1110, 0.8723]],

     [[ 1.5251, -1.0646,  0.3313,  ..., -1.4080, -2.2913,  0.1199],
      [ 0.4624, -2.5551,  0.6143,  ...,  0.6368, -0.0665, -0.4362]],

     [[ 0.2593,  0.6201,  0.1564,  ..., -0.2180, -0.5958, -0.7321],
      [ 0.3348,  1.0992,  0.0550,  ...,  0.0353, -1.0144, -0.5455]],

     ...,

     [[ 0.0432,  0.4610,  0.9161,  ..., -1.1195,  1.6831, -1.3095],
      [ 0.5505, -0.3222,  0.9189,  ..., -1.2325,  1.3278, -1.0136]],

     [[ 2.0462, -0.0515,  0.9965,  ...,  1.2463, -0.8142, -1.7320],
      [ 1.4730,  0.1691,  0.6996,  ...,  0.2591, -0.5551, -1.5194]],

     [[ 2.4220,  0.0410,  4.0391,  ..., -1.6271,  0.4618,  0.7287],
      [ 1.1739,  0.7379,  3.0418,  ..., -1.1209,  1.5043,  0.0819]]]]), tensor([[[ 0.0397,  0.5798, -0.3321,  ...,  0.5298,  0.5357, -0.0300],
     [-0.0603,  0.6514, -0.5388,  ...,  0.0446,  0.4127,  0.6083]]])), params_dict={'model.decoder.layers.3.self_attn.v_proj.weight': tensor([[-0.0052, -0.0568,  0.0551,  ..., -0.0066, -0.0582,  0.0492],
     [-0.0234, -0.0003, -0.0131,  ...,  0.0322, -0.0577, -0.0016],
     [-0.0661, -0.0218, -0.0131,  ..., -0.0362,  0.1081,  0.0036],
     ...,
     [ 0.0139, -0.0892, -0.0718,  ...,  0.0563,  0.0103, -0.0669],
     [-0.0128,  0.0929,  0.0072,  ...,  0.0327, -0.0292, -0.0834],
    [-0.0680, -0.0546, -0.0941,  ...,  0.0471, -0.0535, -0.0479]]), 'model.decoder.layers.3.self_attn.v_proj.bias': tensor([ 0.0026, -0.0165,  0.0110,  ..., -0.0932, -0.0076, -0.0342])}, export_options=ExportOptions(export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, opset_version=17, do_constant_folding=False, dynamic_axes=None, keep_initializers_as_inputs=True, custom_opsets=None, export_modules_as_functions=False), mismatch_error=AssertionError('Tensor-likes are not close!\n\nMismatched elements: 1 / 2048 (0.0%)\nGreatest absolute difference: 3.8743019104003906e-07 at index (0, 1, 493) (up to 1e-07 allowed)\nGreatest relative difference: 0.001469829585403204 at index (0, 1, 493) (up to 0.001 allowed)'), pt_outs=[tensor([[[ 0.1806, -0.3847,  0.0795,  ..., -0.4868, -0.4938, -0.3816],
      [-0.0113,  0.1197, -0.0161,  ..., -0.5222, -0.5008,  0.1503]]])], upper_graph_info=None, lower_graph_info=None, id='110100111', _onnx_graph=graph(%aten::size_1260 : Float(1, 2, 1024, strides=[2048, 1024, 1], requires_grad=0, device=cpu),
   %model.decoder.layers.3.self_attn.v_proj.weight : Float(1024, 1024, strides=[1024, 1], requires_grad=0, device=cpu),
   %model.decoder.layers.3.self_attn.v_proj.bias : Float(1024, strides=[1], requires_grad=0, device=cpu)):

%4 : Float(1024, 1024, strides=[1024, 1], device=cpu) = onnx::Transpose[perm=[1, 0]](%model.decoder.layers.3.self_attn.v_proj.weight) # C:\Users\pierre.dumas\AppData\Local\anaconda3\envs\Donut\Lib\site-packages\torch\nn\modules\linear.py:116:0
%5 : Float(1, 2, 1024, strides=[2048, 1024, 1], device=cpu) = onnx::MatMul(%aten::size_1260, %4) # C:\Users\pierre.dumas\AppData\Local\anaconda3\envs\Donut\Lib\site-packages\torch\nn\modules\linear.py:116:0
%aten::view_1285 : Float(1, 2, 1024, strides=[2048, 1024, 1], requires_grad=0, device=cpu) = onnx::Add(%model.decoder.layers.3.self_attn.v_proj.bias, %5) # C:\Users\pierre.dumas\AppData\Local\anaconda3\envs\Donut\Lib\site-packages\torch\nn\modules\linear.py:116:0
return (%aten::view_1285)
, _EXCLUDED_NODE_KINDS=frozenset({'prim::Constant', 'aten::ScalarImplicit', 'prim::ListConstruct'}))
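
For context, the mismatching leaf is just the v_proj nn.Linear of decoder layer 3: PyTorch runs it as a fused aten::linear, while the exported graph computes Transpose + MatMul + Add. A minimal standalone sketch (using the 1024-wide shapes from the graph above) suggests a single float32 linear of this width already carries rounding noise of the order reported:

import torch

torch.manual_seed(0)
lin = torch.nn.Linear(1024, 1024)
x = torch.randn(1, 2, 1024)

with torch.no_grad():
    out32 = lin(x)  # the fused float32 path
    # float64 reference, written out as Transpose + MatMul + Add like the ONNX graph
    ref64 = x.double() @ lin.weight.double().T + lin.bias.double()

print((out32.double() - ref64).abs().max())  # typically ~1e-6, the order of the reported difference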

Knowing all of that, I don't know how to solve my problem with this information.

To reproduce, here is the code:

from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import torch

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base", cache_dir=r".\cache_dir")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base", cache_dir=r".\cache_dir")

img_path = r".\image.png"
img = Image.open(img_path).convert("RGB")

task_prompt = "<s_iitcdip>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

pixel_values = processor(img, return_tensors="pt").pixel_values

# PyTorch baseline: encode the image, then run the decoder on two tokens
# ([[0, 57500]] = <s> followed by the task-prompt token)
val_encod = model.encoder(pixel_values)
print(val_encod.last_hidden_state)
val_decod = model.decoder(torch.tensor([[0, 57500]]),
                          attention_mask=torch.tensor([[1, 1]]),
                          encoder_hidden_states=val_encod.last_hidden_state,
                          use_cache=True,
                          output_attentions=False,
                          output_hidden_states=False,
                          return_dict=True)
print(val_decod.logits)
print(torch.argmax(val_decod.logits[0][1]))  # prediction after the last input token

To export the models:

#Decoder
import os

current_path = os.getcwd()
model.decoder.eval()
save_path_onnx = current_path + r"\onnx_models\donut_decoder.onnx"
with torch.no_grad():
    torch.onnx.export(
        model.decoder,
        (decoder_input_ids, torch.tensor([[1]]), val_encod.last_hidden_state),
        save_path_onnx,
        export_params=True,
        do_constant_folding=False,
        opset_version=17,
        input_names=['decoder_input_ids', 'attention_mask', 'encoded_image'],
        output_names=['logit', 'past_key_values'],
        verbose=False,
        dynamic_axes={'decoder_input_ids': {1: 'number_of_token'},
                      'attention_mask': {1: 'number_of_token'},
                      'logit': {1: 'number_of_token'}}
    )


#Encoder (os is already imported above)
save_path_onnx = current_path + r"\onnx_models\donut_encoder.onnx"

torch.onnx.export(
    model.encoder,
    (pixel_values,),
    save_path_onnx,
    export_params=True,
    do_constant_folding=True,
    opset_version=17,
    input_names=['pixel_values'],
    output_names=['encoded_image', 'pooler_output'],
    verbose=False
)
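
As a sanity check on both exports (just the standard onnx checker, which validates structure, not values), I verify the graphs are well-formed:

import onnx

for path in [r".\onnx_models\donut_encoder.onnx", r".\onnx_models\donut_decoder.onnx"]:
    onnx.checker.check_model(onnx.load(path))  # raises if the graph is malformed
    print(path, "passed the checker")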

To run the ONNX models:

import onnxruntime as rt
import numpy as np
import time


save_path_encoder_onnx = r".\onnx_models\donut_encoder.onnx"
save_path_decoder_onnx = r".\onnx_models\donut_decoder.onnx"
donut_encoder_onnx = rt.InferenceSession(save_path_encoder_onnx, providers=rt.get_available_providers())
donut_decoder_onnx = rt.InferenceSession(save_path_decoder_onnx, providers=rt.get_available_providers())

number_of_execution = 1
t_start_onnx_model = time.time()
encoded_image = donut_encoder_onnx.run(None, {"pixel_values": np.array(pixel_values)})
pred = 0
number_loop = 0
print(encoded_image[0])
input_ids = np.array([[0, 57500]], dtype=np.int64)  # same two tokens as the PyTorch run
attention_mask = np.array([[1, 1]], dtype=np.int64)
# while number_loop < 1 or pred == 2:
output_decoder = donut_decoder_onnx.run(None, {"decoder_input_ids": input_ids,
                                               "attention_mask": attention_mask,
                                               "encoded_image": encoded_image[0]})
attention_mask = np.append(attention_mask, [[1]], axis=1)
pred = np.argmax(output_decoder[0][0][-1])  # last position, matching torch.argmax(val_decod.logits[0][1])
print("pred : ", pred)
print("attention_mask : ", attention_mask)
input_ids = np.append(input_ids, [[pred]], axis=1)
number_loop += 1
print(input_ids)

print("Inference time for the ONNX model: ", (time.time() - t_start_onnx_model) / number_of_execution, " seconds")
print(output_decoder[0])
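
To quantify the gap instead of only comparing argmax values, the full ONNX logits can be checked against the PyTorch run from the reproduction code above, with a tolerance that allows for float32 rounding:

np.testing.assert_allclose(output_decoder[0],
                           val_decod.logits.detach().numpy(),
                           rtol=1e-3, atol=1e-4)  # prints max abs/rel difference on failure
print("ONNX and PyTorch logits match within float32 tolerance")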

Verification:

from torch.onnx import verification
input_tuple = (torch.tensor([[10, 52000]]), torch.tensor([[1, 1]]), torch.ones([1, 690, 1024]))
info_g = verification.find_mismatch(
    model.decoder, input_tuple, opset_version=17, do_constant_folding=False
)
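
The "up to 1e-07 allowed" and "up to 0.001 allowed" in the report are find_mismatch's default atol and rtol. If nothing is flagged when re-running with a looser atol (a sketch using torch.onnx.verification.VerificationOptions), the remaining differences are plain float32 rounding noise:

options = verification.VerificationOptions(rtol=1e-3, atol=1e-4)  # defaults: rtol=1e-3, atol=1e-7
info_g = verification.find_mismatch(
    model.decoder, input_tuple, opset_version=17, do_constant_folding=False, options=options
)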