Input argument lost when exporting a model to ONNX

I am trying to export the decoder of a Seq2Seq model. The class is defined as follows:

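from typing import Optional

import torch
from torch import nn
from transformers import MarianConfig
from transformers.models.marian.modeling_marian import (
    MarianDecoderLayer,
    MarianPreTrainedModel,
    MarianSinusoidalPositionalEmbedding,
)

# Hypothetical placeholder: the original PAST_KEY_VALUES_TYPE alias is not shown.
PAST_KEY_VALUES_TYPE = Optional[torch.Tensor]


class Decoder(MarianPreTrainedModel):
    def __init__(self, config: MarianConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings

        self.embed_positions = MarianSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
            config.d_model,
            self.padding_idx,
        )
        self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)])
        self.init_weights()

    ...  # other methods omitted

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        decoder_past_key_values: PAST_KEY_VALUES_TYPE = None,
        encoder_past_key_values: PAST_KEY_VALUES_TYPE = None,
    ):
        ...  # body omitted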

The module is adapted from Hugging Face's MarianMTModel.
When I export the model to ONNX, the exported model is missing the encoder_hidden_states input.
I debugged the export function and found that the argument disappears after this line:

params_dict = torch._C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)

This is line 509 of torch/onnx/utils.py.
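As its name suggests, this pass prunes items that nothing in the graph uses. The sketch below is a simplified, hypothetical model of that idea (Node, Graph and eliminate_unused_inputs are illustrative names, not PyTorch's implementation): a graph input that no node consumes is dropped. That would match what happens to encoder_hidden_states here if nothing in the traced graph reads it.

```python
# Simplified, hypothetical model of unused-input elimination.
# This is NOT PyTorch's implementation; it only illustrates the idea that
# a graph input which no node consumes can be pruned.
from dataclasses import dataclass, field


@dataclass
class Node:
    op: str
    inputs: list[str]
    outputs: list[str]


@dataclass
class Graph:
    inputs: list[str]
    nodes: list[Node] = field(default_factory=list)


def eliminate_unused_inputs(graph: Graph) -> Graph:
    """Return a new graph keeping only the inputs read by at least one node."""
    consumed = {name for node in graph.nodes for name in node.inputs}
    kept = [name for name in graph.inputs if name in consumed]
    return Graph(inputs=kept, nodes=list(graph.nodes))


# Toy decoder step whose cross-attention reads cached keys/values,
# so encoder_hidden_states is never consumed by any node.
g = Graph(
    inputs=[
        "decoder_input_embeds",
        "encoder_hidden_states",
        "encoder_attention_mask",
        "decoder_cache_values",
        "encoder_cache_values",
    ],
    nodes=[
        Node("SelfAttention", ["decoder_input_embeds", "decoder_cache_values"], ["h"]),
        Node("CrossAttention", ["h", "encoder_cache_values", "encoder_attention_mask"], ["out"]),
    ],
)
print(eliminate_unused_inputs(g).inputs)
```

The surviving inputs keep their original order, so the toy output lists the same four user inputs as the "after" graph.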
Printing the graph before this pass gives:

graph(%decoder_input_embeds : Float(*, *, 512, strides=[1024, 512, 1], requires_grad=0, device=cpu),
      %encoder_hidden_states : Float(*, *, 512, strides=[3584, 512, 1], requires_grad=0, device=cpu),
      %encoder_attention_mask : Long(*, *, strides=[7, 1], requires_grad=0, device=cpu),
      %decoder_cache_values : Float(*, 8, *, 64, strides=[1536, 192, 64, 1], requires_grad=0, device=cpu),
      %encoder_cache_values : Float(*, 8, *, 64, strides=[3584, 448, 64, 1], requires_grad=0, device=cpu),
      %embed_positions.weight : Float(512, 512, strides=[512, 1], requires_grad=0, device=cpu),
      %layers.0.self_attn.k_proj.weight : Float(512, 512, strides=[512, 1], requires_grad=1, device=cpu),

Printing the graph after the pass gives:

graph(%decoder_input_embeds : Float(*, *, 512, strides=[1024, 512, 1], requires_grad=0, device=cpu),
      %encoder_attention_mask : Long(*, *, strides=[7, 1], requires_grad=0, device=cpu),
      %decoder_cache_values : Float(*, 8, *, 64, strides=[1536, 192, 64, 1], requires_grad=0, device=cpu),
      %encoder_cache_values : Float(*, 8, *, 64, strides=[3584, 448, 64, 1], requires_grad=0, device=cpu),
      %embed_positions.weight : Float(512, 512, strides=[512, 1], requires_grad=0, device=cpu),

So why is the encoder_hidden_states argument lost, and how can I work around it?
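One possible cause, assuming this decoder keeps the cross-attention caching logic of Hugging Face's Marian attention: when cached encoder keys/values are passed, they are used directly and encoder_hidden_states is never projected. torch.onnx.export traces the forward pass with the example inputs, so a Python branch like "if encoder_past_key_values is not None" is fixed at trace time, and only the branch actually taken ends up in the graph. The framework-free sketch below (cross_attention_reads is a hypothetical helper) only records which input each branch would read.

```python
# Hypothetical, framework-free sketch of the cross-attention branch used by
# Marian-style decoders. Real code projects tensors; here we only record
# which inputs each path reads.
from typing import Optional, Sequence


def cross_attention_reads(
    encoder_hidden_states: str,
    encoder_past_key_value: Optional[Sequence[str]],
) -> set[str]:
    """Return the names of the inputs the cross-attention step would read."""
    if encoder_past_key_value is not None:
        # Cached path: keys/values come straight from the cache, so
        # encoder_hidden_states is never touched.
        return set(encoder_past_key_value)
    # First decoding step: keys/values are projected from the encoder output.
    return {encoder_hidden_states}


with_cache = cross_attention_reads("encoder_hidden_states", ["encoder_cache_values"])
without_cache = cross_attention_reads("encoder_hidden_states", None)
print(sorted(with_cache))
print(sorted(without_cache))
```

If this is the cause, two options worth trying: export the first decoding step (no encoder cache, so encoder_hidden_states is consumed) and the cached steps as separate ONNX graphs, or accept that the cached-step graph does not need encoder_hidden_states and stop feeding it at inference time.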