Export of ActorCritic to ONNX fails

Hey there,

I have an ActorCritic (used for PPO) and want to export it to ONNX:

ActorCritic(
  (action_layer): Sequential(
    (0): Linear(in_features=180, out_features=64, bias=True)
    (1): Tanh()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): Tanh()
    (4): Linear(in_features=64, out_features=60, bias=True)
    (5): Softmax(dim=-1)
  )
  (value_layer): Sequential(
    (0): Linear(in_features=180, out_features=64, bias=True)
    (1): Tanh()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): Tanh()
    (4): Linear(in_features=64, out_features=1, bias=True)
  )
)

Relevant code snippet:

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Categorical

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # used in act()

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, n_latent_var):
        super(ActorCritic, self).__init__()

        # actor
        self.action_layer = nn.Sequential(
                nn.Linear(state_dim, n_latent_var),
                nn.Tanh(),
                nn.Linear(n_latent_var, n_latent_var),
                nn.Tanh(),
                nn.Linear(n_latent_var, action_dim),
                nn.Softmax(dim=-1)
                )

        # critic
        self.value_layer = nn.Sequential(
                nn.Linear(state_dim, n_latent_var),
                nn.Tanh(),
                nn.Linear(n_latent_var, n_latent_var),
                nn.Tanh(),
                nn.Linear(n_latent_var, 1)
                )

    def forward(self):
        raise NotImplementedError

    def act(self, state, memory):
        state = torch.from_numpy(state).float().to(device)
        action_probs = self.action_layer(state)
        dist = Categorical(action_probs)
        action = dist.sample()

        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(dist.log_prob(action))

        return action.item()

    def evaluate(self, state, action):
        action_probs = self.action_layer(state)
        dist = Categorical(action_probs)

        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()

        state_value = self.value_layer(state)

        return action_logprobs, torch.squeeze(state_value), dist_entropy

class PPO:
    def __init__(self, state_dim, action_dim, n_latent_var, lr, betas, gamma, K_epochs, eps_clip):
        self.policy = ActorCritic(state_dim, action_dim, n_latent_var).to(device)



exportONNX(ppo.policy, torch.rand(180), str(reward_mean))
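exportONNX isn't shown above; from the traceback below it appears to wrap the internal torch.onnx._export, roughly like this (a reconstruction from the traceback, not the actual function):

def exportONNX(model, input_vector, path):
    # Reconstructed from line 146 of ppo_witches.py in the traceback;
    # torch.onnx._export is the internal variant of torch.onnx.export.
    torch_out = torch.onnx._export(model, input_vector, path + ".onnx",
                                   export_params=True)
    return torch_out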

Error:

Traceback (most recent call last):
  File "ppo_witches.py", line 244, in <module>
    main()
  File "ppo_witches.py", line 236, in main
    exportONNX(ppo.policy, [torch.randn(180), torch.rand(180)], str(reward_mean))
  File "ppo_witches.py", line 146, in exportONNX
    torch_out = torch.onnx._export(model, input_vector, path+".onnx",  export_params=True)
  File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/onnx/__init__.py", line 26, in _export
    result = utils._export(*args, **kwargs)
  File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/onnx/utils.py", line 416, in _export
    fixed_batch_size=fixed_batch_size)
  File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/onnx/utils.py", line 279, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/onnx/utils.py", line 236, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(model, args, _force_outplace=True, _return_inputs_states=True)
  File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/jit/__init__.py", line 277, in _get_trace_graph
    outs = ONNXTracedModule(f, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/jit/__init__.py", line 360, in forward
    self._force_outplace,
  File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/jit/__init__.py", line 347, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 530, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 516, in _slow_forward
    result = self.forward(*input, **kwargs)
TypeError: forward() takes 1 positional argument but 2 were given

The export method tries to call forward(), which in your current implementation doesn't take any input (besides self) and isn't implemented anyway; it just raises NotImplementedError.

Would it work if you passed another flag (e.g. a bool) to forward and then forked to act or evaluate?
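A minimal sketch of that idea (the flag name evaluate_mode and the default arguments are assumptions, not tested against your code):

def forward(self, state, action=None, evaluate_mode=False):
    # Hypothetical fork: evaluate() for training updates, act() otherwise.
    if evaluate_mode:
        return self.evaluate(state, action)
    # act() as posted appends to memory, so it would need a None guard
    # (see the fix below) before it can be called like this.
    return self.act(state, memory=None)

Keep in mind that the ONNX exporter works by tracing, so only the branch taken for the example input ends up in the exported graph.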

Did not know that :slight_smile:

I had to make these changes:

    def forward(self, state_input):
        return torch.tensor(self.act(state_input, None))

    def act(self, state, memory):
        if type(state) is np.ndarray:
            state = torch.from_numpy(state).float().to(device)
        action_probs = self.action_layer(state)
        # here make a filter for only possible actions!
        #probs = probs * memory.leagalCards
        dist = Categorical(action_probs)

        action = dist.sample()

        if memory is not None:
            memory.states.append(state)
            memory.actions.append(action)
            memory.logprobs.append(dist.log_prob(action))

        return action.item()
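With that forward in place, an export call along these lines should run (a sketch; torch.onnx.export is the public counterpart of the _export call in the traceback, and the file name is made up):

import torch

policy = ppo.policy.eval()     # ppo as constructed by the PPO class above
dummy_state = torch.rand(180)  # matches state_dim = 180
torch.onnx.export(policy, dummy_state, "policy.onnx", export_params=True)

One caveat: action.item() converts the sampled action to a Python number, so tracing will likely bake it into the graph as a constant (usually with a tracer warning). If that is a problem, exporting just self.action_layer, i.e. the deterministic action probabilities, avoids it.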