Hey there,
I have this model:
# Policy network (an nn.Module) built from two Conv2d layers and a linear head.
# The snippet is abridged by its author (see the `# ...` markers), and the
# indentation was lost in the paste: the lines below are the class body.
class PlayingPolicy(nn.Module):
def __init__(self):
super().__init__()
# Parameters
self.gamma = 0.9 # reward discount factor
# Network
in_channels = 4 # four players
out_channels = 8
# Both convolutions take `in_channels` (4) input channels and produce
# `out_channels` (8); they differ in kernel width, stride and dilation.
self.conv418 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 4), stride=1, dilation=(1, 8))
self.conv881 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 8), stride=8, dilation=(1, 1))
# Linear head: 104 input features -> one output per entry of belot.cards.
# The names in the trailing comment are Croatian: KARO = diamonds,
# HERC = hearts, PIK = spades, TREF = clubs, "dalje" = "further"/"next".
self.classify = nn.Linear(104, len(belot.cards))# len(belot.cards)=32 # KARO, HERC, PIK, TREF, dalje
# Optimizer
# SGD is built from self.parameters() as they exist at this point;
# parameters of modules assigned after this line would not be included.
self.optimizer = optim.SGD(self.parameters(), lr=1e-2, momentum=0.9)
# ...
def forward(self, state: np.ndarray, bidder, trump, legalCards):
# datatype: <class 'numpy.ndarray'>, <enum 'Suit'>, <class 'int'>, <class 'list'>
# NOTE(review): the RuntimeError quoted later in this post shows the JIT
# tracer used by the ONNX export rejecting a numpy.ndarray input. `state`
# presumably has to be passed as a torch.Tensor (e.g. via
# torch.from_numpy) for export, and the enum argument may be rejected the
# same way -- confirm.
# ....
When I try to export it to ONNX like this:
# Example inputs handed to the exporter; the four entries presumably line up
# with forward()'s (state, bidder, trump, legalCards) parameters.
input_nn = (playingState, bidderIndex, trumpIndex, [])
# Export `policy` to `path`, storing the parameters with the graph
# (export_params=True).
# NOTE(review): `_export` is an underscore-prefixed (private) helper; the
# documented entry point is torch.onnx.export -- confirm which one is intended.
torch.onnx._export(policy, input_nn, path, export_params=True)
I get this error:
Traceback (most recent call last):
File "train.py", line 64, in <module>
train_player()
File "train.py", line 38, in train_player
game.saveNetworks()
File "/home/markus/Documents/06_Software_Projects/belot/game/play.py", line 402, in saveNetworks
player.saveNetwork()
File "/home/markus/Documents/06_Software_Projects/belot/players/PlayerRL/player.py", line 183, in saveNetwork
torch.onnx._export(policy, input_nn, path, export_params=True)
File "/home/markus/Documents/06_Software_Projects/belot/belot_env/lib/python3.6/site-packages/torch/onnx/__init__.py", line 26, in _export
result = utils._export(*args, **kwargs)
File "/home/markus/Documents/06_Software_Projects/belot/belot_env/lib/python3.6/site-packages/torch/onnx/utils.py", line 416, in _export
fixed_batch_size=fixed_batch_size)
File "/home/markus/Documents/06_Software_Projects/belot/belot_env/lib/python3.6/site-packages/torch/onnx/utils.py", line 279, in _model_to_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
File "/home/markus/Documents/06_Software_Projects/belot/belot_env/lib/python3.6/site-packages/torch/onnx/utils.py", line 236, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(model, args, _force_outplace=True, _return_inputs_states=True)
File "/home/markus/Documents/06_Software_Projects/belot/belot_env/lib/python3.6/site-packages/torch/jit/__init__.py", line 277, in _get_trace_graph
outs = ONNXTracedModule(f, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/home/markus/Documents/06_Software_Projects/belot/belot_env/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "/home/markus/Documents/06_Software_Projects/belot/belot_env/lib/python3.6/site-packages/torch/jit/__init__.py", line 332, in forward
in_vars, in_desc = _flatten(args)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type numpy.ndarray
Any ideas how to solve this one?
I already checked here, but it did not help me…
PS: My model:
PlayingPolicy(
(conv418): Conv2d(4, 8, kernel_size=(3, 4), stride=(1, 1), dilation=(1, 8))
(conv881): Conv2d(4, 8, kernel_size=(3, 8), stride=(8, 8))
(classify): Linear(in_features=104, out_features=32, bias=True)
(criterion): PolicyGradientLoss()
)