`torch.jit.trace_module`: how to capture in-place modification of a passed dict?

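Fairseq Transformer models modify the passed incremental-state dict in place without returning it (Hugging Face Transformers, on the other hand, return it). ONNX export seems to drop these in-place operations, and tracing the minimal example below shows the same thing: the dict writes do not appear in the traced code. Is it possible to export these operations correctly?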

import torch
from typing import Any, Dict, List, Optional
from torch import Tensor
 
class InplaceModule(torch.nn.Module):
    """Toy module that stores ``x`` into the caller's ``state`` dict in place.

    ``forward`` returns only the bias-free linear projection of ``x``. Writing
    ``state['new_level_1']['new_level_2'] = x`` is a side effect on the dict
    the caller passed in and is never part of the return value. In the code
    produced by ``torch.jit.trace_module`` for this module, only the
    ``embeddings`` call appears.
    """

    def __init__(self):
        super().__init__()
        # Projects 1 input feature to 10 output features, with no bias term.
        self.embeddings = torch.nn.Linear(
            in_features=1, out_features=10, bias=False
        )

    def forward(self, x, state: Dict[str, Dict[str, Tensor]]):
        # Build the inner level first, then attach it under 'new_level_1'.
        # Any previous 'new_level_1' entry is replaced. The caller's dict is
        # mutated in place and is not returned.
        inner: Dict[str, Tensor] = {'new_level_2': x}
        state['new_level_1'] = inner
        return self.embeddings(x)
    
# Example nested state: {outer_key: {inner_key: tensor}}.
dummy_state = {'level_1': {'level_2': torch.tensor([[3]], dtype=torch.float32)}}
# Outside TorchScript, torch.jit.annotate just returns its value unchanged.
# The type given here matches InplaceModule.forward's declared
# `state: Dict[str, Dict[str, Tensor]]` so the two stay consistent.
dummy_state = torch.jit.annotate(Dict[str, Dict[str, Tensor]], dummy_state)
dummy_input = torch.tensor([[3]], dtype=torch.float32)
# Tracing executes forward on these example inputs, so dummy_state gains a
# 'new_level_1' entry here. The printed traced code, however, contains only
# the embeddings call: the in-place write to the dict is not captured.
model = torch.jit.trace_module(InplaceModule(), dict(forward=(dummy_input, dummy_state,)))
print(model.code)
def forward(self,
    input: Tensor,
    argument_2: Dict[str, Dict[str, Tensor]]) -> Tensor:
  return (self.embeddings).forward(input, )