Hi everyone,
I’m trying to convert the Donut model to Onnx (this one from huggingface : naver-clova-ix/donut-base)
i managed to understand that i need to separate the model by 2, first part the encoder and the second the decoder
The encoder export seems to work well but the decoder in onnx format give totally differents results than pytorch results which lead to wrong prediction
When i’m using the torch.onnx.verification.find_mismatch() function i got this :
===================== Mismatch info for graph partition : ======================
================================ Mismatch error ================================
Tensor-likes are not close!Mismatched elements: 4 / 115050 (0.0%)
Greatest absolute difference: 4.604458808898926e-06 at index (0, 1, 11195) (up to 1e-07 allowed)
Greatest relative difference: 0.0069338479079306126 at index (0, 1, 11195) (up to 0.001 allowed)
==================================== Tree: =====================================
497 X __248 X __124 ✓
id: | id: 0 | id: 00
| |
| |__124 X __62 X __31 ✓
| id: 01 | id: 010 | id: 0100
| | |
| | |__31 X __15 X __7 ✓
| | id: 0101 | id: 01010 | id: 010100
| | | |
| | | |__8 ✓
| | | id: 010101
| | |
| | |__16 ✓
| | id: 01011
| |
| |__62 ✓
| id: 011
|
|__249 X __124 X __62 X __31 X __15 ✓
id: 1 | id: 10 | id: 100 | id: 1000 | id: 10000
| | | |
| | | |__16 X __8 X __4 ✓
| | | id: 10001 | id: 100010 | id: 1000100
| | | | |
| | | | |__4 ✓
| | | | id: 1000101
| | | |
| | | |__8 X __4 ✓
| | | id: 100011 | id: 1000110
| | | |
| | | |__4 ✓
| | | id: 1000111
| | |
| | |__31 ✓
| | id: 1001
| |
| |__62 ✓
| id: 101
|
|__125 X __62 X __31 X __15 X __7 X __3 ✓
id: 11 | id: 110 | id: 1100 | id: 11000 | id: 110000 | id: 1100000
| | | | |
| | | | |__4 ✓
| | | | id: 1100001
| | | |
| | | |__8 X __4 ✓
| | | id: 110001 | id: 1100010
| | | |
| | | |__4 ✓
| | | id: 1100011
| | |
| | |__16 ✓
| | id: 11001
| |
| |__31 X __15 X __7 X __3 ✓
| id: 1101 | id: 11010 | id: 110100 | id: 1101000
| | | |
| | | |__4 X __2 ✓
| | | id: 1101001 | id: 11010010
| | | |
| | | |__2 X __1 ✓
| | | id: 11010011 | id: 110100110
| | | |
| | | |__1 X (aten::linear)
| | | id: 110100111
| | |
| | |__8 ✓
| | id: 110101
| |
| |__16 ✓
| id: 11011
|
|__63 X __31 ✓
id: 111 | id: 1110
|
|__32 X __16 ✓
id: 1111 | id: 11110
|
|__16 ✓
id: 11111
=========================== Mismatch leaf subgraphs: ===========================
[‘01010’, ‘100010’, ‘100011’, ‘110000’, ‘110001’, ‘110100111’, ‘1111’]
============================= Mismatch node kinds: =============================
{‘aten::linear’: 1}
For the leaf with the id “110100111”, the function find_partition give me :
GraphInfo(graph=graph(%aten::reshape_1284 : Float(1, 16, 2, 64, strides=[2048, 128, 64, 1], requires_grad=0, device=cpu),
%aten::size_1260 : Float(1, 2, 1024, strides=[2048, 1024, 1], requires_grad=0, device=cpu),
%model.decoder.layers.3.self_attn.v_proj.weight : Float(1024, 1024, strides=[1024, 1], requires_grad=1, device=cpu),
%model.decoder.layers.3.self_attn.v_proj.bias : Float(1024, strides=[1], requires_grad=1, device=cpu)):
%aten::view_1285 : Float(1, 2, 1024, strides=[2048, 1024, 1], requires_grad=0, device=cpu) = aten::linear(%aten::size_1260, %model.decoder.layers.3.self_attn.v_proj.weight, %model.decoder.layers.3.self_attn.v_proj.bias) # C:\Users\pierre.dumas\AppData\Local\anaconda3\envs\Donut\Lib\site-packages\torch\nn\modules\linear.py:116:0
return (%aten::view_1285)
, input_args=(tensor([[[[ 0.0417, 0.3818, -0.5706, …, 0.4427, -0.2359, 0.6942],
[ 0.0660, -0.1410, -0.5272, …, -0.2237, -0.1110, 0.8723]],[[ 1.5251, -1.0646, 0.3313, ..., -1.4080, -2.2913, 0.1199], [ 0.4624, -2.5551, 0.6143, ..., 0.6368, -0.0665, -0.4362]], [[ 0.2593, 0.6201, 0.1564, ..., -0.2180, -0.5958, -0.7321], [ 0.3348, 1.0992, 0.0550, ..., 0.0353, -1.0144, -0.5455]], ..., [[ 0.0432, 0.4610, 0.9161, ..., -1.1195, 1.6831, -1.3095], [ 0.5505, -0.3222, 0.9189, ..., -1.2325, 1.3278, -1.0136]], [[ 2.0462, -0.0515, 0.9965, ..., 1.2463, -0.8142, -1.7320], [ 1.4730, 0.1691, 0.6996, ..., 0.2591, -0.5551, -1.5194]], [[ 2.4220, 0.0410, 4.0391, ..., -1.6271, 0.4618, 0.7287], [ 1.1739, 0.7379, 3.0418, ..., -1.1209, 1.5043, 0.0819]]]]), tensor([[[ 0.0397, 0.5798, -0.3321, ..., 0.5298, 0.5357, -0.0300], [-0.0603, 0.6514, -0.5388, ..., 0.0446, 0.4127, 0.6083]]])), params_dict={'model.decoder.layers.3.self_attn.v_proj.weight': tensor([[-0.0052, -0.0568, 0.0551, ..., -0.0066, -0.0582, 0.0492],
[-0.0234, -0.0003, -0.0131, ..., 0.0322, -0.0577, -0.0016],
[-0.0661, -0.0218, -0.0131, ..., -0.0362, 0.1081, 0.0036],
...,
[ 0.0139, -0.0892, -0.0718, ..., 0.0563, 0.0103, -0.0669],
[-0.0128, 0.0929, 0.0072, ..., 0.0327, -0.0292, -0.0834],
[-0.0680, -0.0546, -0.0941, ..., 0.0471, -0.0535, -0.0479]]), 'model.decoder.layers.3.self_attn.v_proj.bias': tensor([ 0.0026, -0.0165, 0.0110, ..., -0.0932, -0.0076, -0.0342])}, export_options=ExportOptions(export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, opset_version=17, do_constant_folding=False, dynamic_axes=None, keep_initializers_as_inputs=True, custom_opsets=None, export_modules_as_functions=False), mismatch_error=AssertionError('Tensor-likes are not close!\n\nMismatched elements: 1 / 2048 (0.0%)\nGreatest absolute difference: 3.8743019104003906e-07 at index (0, 1, 493) (up to 1e-07 allowed)\nGreatest relative difference: 0.001469829585403204 at index (0, 1, 493) (up to 0.001 allowed)'), pt_outs=[tensor([[[ 0.1806, -0.3847, 0.0795, ..., -0.4868, -0.4938, -0.3816],
[-0.0113, 0.1197, -0.0161, ..., -0.5222, -0.5008, 0.1503]]])], upper_graph_info=None, lower_graph_info=None, id='110100111', _onnx_graph=graph(%aten::size_1260 : Float(1, 2, 1024, strides=[2048, 1024, 1], requires_grad=0, device=cpu),
%model.decoder.layers.3.self_attn.v_proj.weight : Float(1024, 1024, strides=[1024, 1], requires_grad=0, device=cpu),
%model.decoder.layers.3.self_attn.v_proj.bias : Float(1024, strides=[1], requires_grad=0, device=cpu)):
%4 : Float(1024, 1024, strides=[1024, 1], device=cpu) = onnx::Transposeperm=[1, 0] # C:\Users\pierre.dumas\AppData\Local\anaconda3\envs\Donut\Lib\site-packages\torch\nn\modules\linear.py:116:0
%5 : Float(1, 2, 1024, strides=[2048, 1024, 1], device=cpu) = onnx::MatMul(%aten::size_1260, %4) # C:\Users\pierre.dumas\AppData\Local\anaconda3\envs\Donut\Lib\site-packages\torch\nn\modules\linear.py:116:0
%aten::view_1285 : Float(1, 2, 1024, strides=[2048, 1024, 1], requires_grad=0, device=cpu) = onnx::Add(%model.decoder.layers.3.self_attn.v_proj.bias, %5) # C:\Users\pierre.dumas\AppData\Local\anaconda3\envs\Donut\Lib\site-packages\torch\nn\modules\linear.py:116:0
return (%aten::view_1285)
, _EXCLUDED_NODE_KINDS=frozenset({‘prim::Constant’, ‘aten::ScalarImplicit’, ‘prim::ListConstruct’}))
Knowing all of that, i don’t know how i can solve my problem with these informations
To reproduce, here is the code :
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
processor = DonutProcessor.from_pretrained(r"naver-clova-ix/donut-base", cache_dir = r".\cache_dir")
model = VisionEncoderDecoderModel.from_pretrained(r"naver-clova-ix/donut-base", cache_dir = r".\cache_dir")
from PIL import Image
img_path = r".\image.png"
img = Image.open(img_path).convert("RGB")
task_prompt = "<s_iitcdip>"
decoder_input_ids = processor.tokenizer(task_prompt , add_special_tokens=False, return_tensors="pt").input_ids
pixel_values = processor(img, return_tensors="pt").pixel_values
import time
import torch
val_encod = model.encoder(pixel_values)
print(val_encod.last_hidden_state)
val_decod = model.decoder(torch.tensor([[0, 57500]]),
attention_mask = torch.tensor([[ 1, 1]]),
encoder_hidden_states = val_encod.last_hidden_state,
use_cache = True,
output_attentions = False,
output_hidden_states = False,
return_dict = True
)
print(val_decod.logits)
print(torch.argmax(val_decod.logits[0][1]))
To export models :
#Decoder
import os
current_path = os.getcwd()
model.decoder.eval()
save_path_onnx = current_path + r"\onnx_models\donut_decoder.onnx"
with torch.no_grad():
torch.onnx.export(
model.decoder,
(decoder_input_ids, torch.tensor([[1]]), val_encod.last_hidden_state),
save_path_onnx,
export_params=True,
do_constant_folding=False,
opset_version=17,
input_names = ['decoder_input_ids', 'attention_mask', 'encoded_image'],
output_names = ['logit', 'past_key_values'],
verbose=False,
dynamic_axes={'decoder_input_ids' : {1 : 'number_of_token'}, 'attention_mask' : {1 : 'number_of_token'},
'logit' : {1 : 'number_of_token'}}
)
#Encoder
import os
current_path = os.getcwd()
save_path_onnx = current_path + r"\onnx_models\donut_encoder.onnx"
torch.onnx.export(
model.encoder,
(pixel_values),
save_path_onnx,
export_params=True,
do_constant_folding=True,
opset_version=17,
input_names = ['pixel_values'],
output_names = ['encoded_image', 'pooler_output'],
verbose=False
)
To execute the onnx model :
import onnxruntime as rt
import numpy as np
import time
save_path_encoder_onnx = r".\onnx_models\donut_encoder.onnx"
save_path_decoder_onnx = r".\onnx_models\donut_decoder.onnx"
donut_encoder_onnx = rt.InferenceSession(save_path_encoder_onnx, providers=rt.get_available_providers())
donut_decoder_onnx = rt.InferenceSession(save_path_decoder_onnx, providers=rt.get_available_providers())
number_of_execution=1
t_start_onnx_model = time.time()
encoded_image = donut_encoder_onnx.run(None, {"pixel_values": np.array(pixel_values)})
pred = 0
number_loop =0
print(encoded_image[0])
input = np.array([[0, 57500]])
attention_mask = np.array([[1, 1]], dtype=np.int64)
#while(number_loop < 1 or pred==2):
output_decoder = donut_decoder_onnx.run(None, {"decoder_input_ids": np.array(input, dtype = "int64"),
"attention_mask": attention_mask,
"encoded_image": encoded_image[0]})#encoded_image[0]})
attention_mask = np.append(attention_mask, [[1]], axis=1)
pred = np.argmax(output_decoder[0][0][0])
print("pred : ", pred)
print("attention_mask : ", attention_mask)
input = np.append(input, [[pred]], axis=1)
number_loop+=1
print(input)
print("Inference Time for Onnx model opti : ", (time.time() - t_start_onnx_model)/number_of_execution ," seconds")
print(output_decoder[0])
Verification :
from torch.onnx import verification
input_tuple = (torch.tensor([[10, 52000]]), torch.tensor([[1, 1]]), torch.ones([1, 690, 1024]))
info_g = verification.find_mismatch(
model.decoder, input_tuple, opset_version=17, do_constant_folding=False
)