Hey there,
I have an ActorCritic (used for PPO) and want to export it as onnx:
ActorCritic(
(action_layer): Sequential(
(0): Linear(in_features=180, out_features=64, bias=True)
(1): Tanh()
(2): Linear(in_features=64, out_features=64, bias=True)
(3): Tanh()
(4): Linear(in_features=64, out_features=60, bias=True)
(5): Softmax(dim=-1)
)
(value_layer): Sequential(
(0): Linear(in_features=180, out_features=64, bias=True)
(1): Tanh()
(2): Linear(in_features=64, out_features=64, bias=True)
(3): Tanh()
(4): Linear(in_features=64, out_features=1, bias=True)
)
)
Code (Snippet - Relevant)
class ActorCritic(nn.Module):
def __init__(self, state_dim, action_dim, n_latent_var):
super(ActorCritic, self).__init__()
# actor
self.action_layer = nn.Sequential(
nn.Linear(state_dim, n_latent_var),
nn.Tanh(),
nn.Linear(n_latent_var, n_latent_var),
nn.Tanh(),
nn.Linear(n_latent_var, action_dim),
nn.Softmax(dim=-1)
)
# critic
self.value_layer = nn.Sequential(
nn.Linear(state_dim, n_latent_var),
nn.Tanh(),
nn.Linear(n_latent_var, n_latent_var),
nn.Tanh(),
nn.Linear(n_latent_var, 1)
)
def forward(self):
raise NotImplementedError
def act(self, state, memory):
state = torch.from_numpy(state).float().to(device)
action_probs = self.action_layer(state)
dist = Categorical(action_probs)
action = dist.sample()
memory.states.append(state)
memory.actions.append(action)
memory.logprobs.append(dist.log_prob(action))
return action.item()
def evaluate(self, state, action):
action_probs = self.action_layer(state)
dist = Categorical(action_probs)
action_logprobs = dist.log_prob(action)
dist_entropy = dist.entropy()
state_value = self.value_layer(state)
return action_logprobs, torch.squeeze(state_value), dist_entropy
lass PPO:
def __init__(self, state_dim, action_dim, n_latent_var, lr, betas, gamma, K_epochs, eps_clip):
self.policy = ActorCritic(state_dim, action_dim, n_latent_var).to(device)
exportONNX(ppo.policy, torch.rand(180), str(reward_mean))
Error:
raceback (most recent call last):
File "ppo_witches.py", line 244, in <module>
main()
File "ppo_witches.py", line 236, in main
exportONNX(ppo.policy, [torch.randn(180), torch.rand(180)], str(reward_mean))
File "ppo_witches.py", line 146, in exportONNX
torch_out = torch.onnx._export(model, input_vector, path+".onnx", export_params=True)
File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/onnx/__init__.py", line 26, in _export
result = utils._export(*args, **kwargs)
File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/onnx/utils.py", line 416, in _export
fixed_batch_size=fixed_batch_size)
File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/onnx/utils.py", line 279, in _model_to_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/onnx/utils.py", line 236, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(model, args, _force_outplace=True, _return_inputs_states=True)
File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/jit/__init__.py", line 277, in _get_trace_graph
outs = ONNXTracedModule(f, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/jit/__init__.py", line 360, in forward
self._force_outplace,
File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/jit/__init__.py", line 347, in wrapper
outs.append(self.inner(*trace_inputs))
File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 530, in __call__
result = self._slow_forward(*input, **kwargs)
File "/home/mlamprecht/Documents/mcts_cardgame/my_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 516, in _slow_forward
result = self.forward(*input, **kwargs)
TypeError: forward() takes 1 positional argument but 2 were given