I am trying to export the decoder of a Seq2Seq model to ONNX. The class is defined as follows:
import torch
from torch import nn
from transformers import MarianConfig
from transformers.models.marian.modeling_marian import (
    MarianDecoderLayer,
    MarianPreTrainedModel,
    MarianSinusoidalPositionalEmbedding,
)

class Decoder(MarianPreTrainedModel):
    def __init__(self, config: MarianConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.embed_positions = MarianSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            self.padding_idx,
        )
        self.layers = nn.ModuleList(
            [MarianDecoderLayer(config) for _ in range(config.decoder_layers)]
        )
        self.init_weights()
        ....

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        decoder_past_key_values: PAST_KEY_VALUES_TYPE = None,
        encoder_past_key_values: PAST_KEY_VALUES_TYPE = None,
    ):
        ....
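For reference, the export is invoked roughly like this (a minimal sketch: the checkpoint name is just an example, the dummy shapes are read off the graph dump below, and the single stacked tensor per cache only illustrates the elided PAST_KEY_VALUES_TYPE layout):

config = MarianConfig.from_pretrained("Helsinki-NLP/opus-mt-en-de")  # example checkpoint
decoder = Decoder(config).eval()

# Dummy inputs; shapes follow the dims printed in the graph dump below.
inputs_embeds = torch.randn(1, 2, 512)
encoder_hidden_states = torch.randn(1, 7, 512)
encoder_attention_mask = torch.ones(1, 7, dtype=torch.long)
decoder_past_key_values = torch.randn(1, 8, 3, 64)  # illustrative cache layout
encoder_past_key_values = torch.randn(1, 8, 7, 64)  # illustrative cache layout

torch.onnx.export(
    decoder,
    (inputs_embeds, encoder_hidden_states, encoder_attention_mask,
     decoder_past_key_values, encoder_past_key_values),
    "decoder.onnx",
    input_names=["decoder_input_embeds", "encoder_hidden_states",
                 "encoder_attention_mask", "decoder_cache_values",
                 "encoder_cache_values"],
    # dynamic_axes omitted here for brevity
)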
The module is adapted from the Hugging Face MarianMTModel. When I export the model to ONNX, the exported model loses the encoder_hidden_states argument.
I debugged the export function and found that the argument is lost after this line (line 509 of torch/onnx/utils.py):

params_dict = torch._C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
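Judging by its name, this pass eliminates graph items that have no uses. To check whether encoder_hidden_states is actually consumed in the traced graph, the uses of each graph input can be inspected (a minimal sketch, reusing the dummy inputs from the export call above):

traced = torch.jit.trace(
    decoder,
    (inputs_embeds, encoder_hidden_states, encoder_attention_mask,
     decoder_past_key_values, encoder_past_key_values),
)
for inp in traced.graph.inputs():
    # An input with zero uses is a candidate for elimination by the ONNX pass.
    print(inp.debugName(), "-> uses:", len(inp.uses()))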
Before that call, I printed the graph:
graph(%decoder_input_embeds : Float(*, *, 512, strides=[1024, 512, 1], requires_grad=0, device=cpu),
%encoder_hidden_states : Float(*, *, 512, strides=[3584, 512, 1], requires_grad=0, device=cpu),
%encoder_attention_mask : Long(*, *, strides=[7, 1], requires_grad=0, device=cpu),
%decoder_cache_values : Float(*, 8, *, 64, strides=[1536, 192, 64, 1], requires_grad=0, device=cpu),
%encoder_cache_values : Float(*, 8, *, 64, strides=[3584, 448, 64, 1], requires_grad=0, device=cpu),
%embed_positions.weight : Float(512, 512, strides=[512, 1], requires_grad=0, device=cpu),
%layers.0.self_attn.k_proj.weight : Float(512, 512, strides=[512, 1], requires_grad=1, device=cpu),
After the call, the graph becomes:
graph(%decoder_input_embeds : Float(*, *, 512, strides=[1024, 512, 1], requires_grad=0, device=cpu),
%encoder_attention_mask : Long(*, *, strides=[7, 1], requires_grad=0, device=cpu),
%decoder_cache_values : Float(*, 8, *, 64, strides=[1536, 192, 64, 1], requires_grad=0, device=cpu),
%encoder_cache_values : Float(*, 8, *, 64, strides=[3584, 448, 64, 1], requires_grad=0, device=cpu),
%embed_positions.weight : Float(512, 512, strides=[512, 1], requires_grad=0, device=cpu),
So why is the encoder_hidden_states argument lost, and what are some approaches to solve this?
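One workaround I am considering, assuming the input is dropped because the traced graph never reads it (for example, if the cross-attention keys and values come entirely from encoder_past_key_values, then encoder_hidden_states would be unused): add a numerically inert use of encoder_hidden_states in forward so the tracer keeps the input. An untested sketch:

# Hypothetical change inside Decoder.forward, before the decoder layers:
# the multiply-by-zero keeps a graph dependency on encoder_hidden_states
# without changing the computed values.
inputs_embeds = inputs_embeds + 0.0 * encoder_hidden_states.sum()

I am not sure whether such a no-op would survive the ONNX optimization passes, so a cleaner solution would be welcome.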