Does ONNX export support models like the one below, which hold modifiable state that can be retrieved later? The end goal is to export a transformer beam-search decoding algorithm.
import torch
from typing import Any, Dict, List, Optional
from torch import Tensor
class Model(torch.nn.Module):
    """Identity module that remembers the most recent ``forward`` input.

    Each call to :meth:`forward` stores its argument under the key ``'abc'``
    in :attr:`state`. :meth:`get_state` returns that same dictionary object,
    not a copy, so callers see later updates.
    """

    # Class-level type declaration for the attribute assigned in __init__.
    # NOTE(review): presumably here so TorchScript can type the dict; tracing
    # does not consult it. Confirm the intended export path.
    state: Dict[str, Tensor]

    def __init__(self):
        super().__init__()
        # A plain Python dict, not a registered buffer or parameter.
        self.state = dict()

    def forward(self, x):
        """Record ``x`` in ``self.state['abc']`` and return ``x`` unchanged."""
        # NOTE(review): this write is a Python-side effect rather than a tensor
        # op. A traced graph presumably will not capture it; verify before
        # relying on it for ONNX export.
        current = self.state
        current['abc'] = x
        return x

    def get_state(self):
        """Return the live state dictionary (the object itself, not a copy)."""
        return self.state
# Example input for tracing: a 1x1 float32 tensor holding the value 3.
dummy_input = torch.tensor([[3]], dtype=torch.float32)

# One tuple of example inputs per method to trace; ``get_state`` takes none.
inputs_by_method = {"forward": (dummy_input,), "get_state": ()}

# NOTE(review): this call raises RuntimeError ("Encountering a dict at the
# output of the tracer ...") because ``get_state`` returns a dict and the
# tracer is strict by default. Passing strict=False would silence the check.
# However, ``get_state`` is traced with no inputs, so its result is presumably
# baked in as a constant, not the state written by ``forward``. Confirm before
# using that for export; returning the state explicitly from ``forward`` is
# likely the safer design.
model = torch.jit.trace_module(Model(), inputs_by_method)
print(model.code)
Traceback (most recent call last):
File "bug.py", line 20, in <module>
model = torch.jit.trace_module(Model(), dict(forward=(dummy_input,), get_state=tuple() ))
File "/miniconda/lib/python3.8/site-packages/torch/jit/__init__.py", line 1118, in trace_module
_check_trace([inputs], func, check_trace_method,
File "/miniconda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
return func(*args, **kwargs)
File "/miniconda/lib/python3.8/site-packages/torch/jit/__init__.py", line 598, in _check_trace
check_mod = torch.jit.trace_module(
File "/miniconda/lib/python3.8/site-packages/torch/jit/__init__.py", line 1109, in trace_module
module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, strict, _force_outplace)
RuntimeError: Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.