I am trying to convert the following PyTorch model I found on GitHub to the ONNX format:
I am working with the first model from the demo3 notebook.
I had some trouble creating the dummy input for it, so I wrote the following code:
```python
import torch
import torch.onnx

# train_dataset and model come from the demo3 notebook.
# To get the right range and order of the embedding inputs: the first 5 columns
# of train_dataset.deep are the embedding (categorical) inputs, so record the
# largest index that appears in each of them.
maxes = []
for i in range(5):
    column = [row[i] for row in train_dataset.deep]
    maxes.append(int(max(column)))

# Creating the dummy input
dummy_wide_input = torch.randn(1, 798)
# One random index per embedding column (randint's upper bound is exclusive),
# cast to float so it can be concatenated with the 2 continuous features.
embedding_part = torch.cat([torch.randint(0, m, (1,)).float() for m in maxes])
dummy_deep_input = torch.cat((embedding_part, torch.randn(2))).unsqueeze(0)
dummy_input = (dummy_wide_input, dummy_deep_input)

# Exporting to ONNX
torch.onnx.export(model, dummy_input, 'wideanddeep.onnx')
```
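To check the shape of the deep input without the notebook, here is a self-contained sketch of the same construction. The helper make_dummy_deep_input and the maxes values are hypothetical stand-ins for the dataset-derived ones; the result should be a single row of 5 embedding indices followed by 2 continuous values.

```python
import torch

def make_dummy_deep_input(maxes, n_continuous=2):
    """Return a (1, len(maxes) + n_continuous) float tensor: one random valid
    index per embedding column, followed by random continuous features."""
    # randint's upper bound is exclusive, so each index lies in [0, m - 1]
    indices = torch.cat([torch.randint(0, m, (1,)) for m in maxes]).float()
    continuous = torch.randn(n_continuous)
    return torch.cat((indices, continuous)).unsqueeze(0)

# Hypothetical per-column maxima; in the notebook they come from train_dataset.deep
maxes = [10, 4, 7, 3, 15]
dummy_deep_input = make_dummy_deep_input(maxes)
print(dummy_deep_input.shape)  # torch.Size([1, 7])
```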
First I ran into a problem with the missing ATen index operation, so I tried to implement it myself in torch/onnx/symbolic.py, but I am far from confident that this implementation is right:
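```python
# Added to torch/onnx/symbolic.py: fall back to an ATen op for aten::index.
# Note: only args[0] (the tensor being indexed) is passed on as an input.
def index(g, *args, **kwargs):
    return g.op("ATen", args[0], operator_s="index")
```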
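If the aten::index node comes from indexing a tensor with another tensor (an assumption, not verified against this model), an alternative to a hand-written ATen fallback is to express the same lookup with torch.index_select, which, as far as I know, the exporter already maps to an ONNX Gather. A minimal sketch checking that both forms agree; gather_columns, x and cols are hypothetical:

```python
import torch

def gather_columns(x, cols):
    # Advanced indexing with an index tensor typically traces to aten::index
    via_indexing = x[:, cols]
    # index_select performs the same column lookup as a single op
    via_index_select = torch.index_select(x, 1, cols)
    return via_indexing, via_index_select

x = torch.randn(4, 7)           # e.g. a batch of 4 deep inputs
cols = torch.tensor([0, 2, 3])  # hypothetical column indices
a, b = gather_columns(x, cols)
print(torch.equal(a, b))  # True
```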
After this I finally managed to get an ONNX representation, but I really don't understand how the Embedding operation is represented in it. When I view the network in the Netron visualizer, the Embedding is shown like this:
First of all, I don't understand why the Cast operation casts to 7 rather than to an actual type. When I run onnx.checker.check_model on this model, I even get an error from it:
ValidationError: Attribute 'to' is expected to have field 's'
Besides that, I don't see how this series of operations is the embedding at all. Does that make any sense?
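An embedding layer is only a row lookup in its weight matrix, indexed by int64 ids. As far as I know, PyTorch's ONNX exporter emits that lookup as a Gather on the weight initializer, so in an exported graph an embedding can appear as a Cast of the float index column to int64 followed by a Gather (or, with the ATen fallback above, an ATen index node) rather than as a node called Embedding. The sketch below uses hypothetical sizes and checks that three formulations of the lookup agree in PyTorch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(num_embeddings=10, embedding_dim=4)  # hypothetical sizes

# The embedding columns of the deep input are floats, so they need an int64 cast first
raw = torch.tensor([3.0, 7.0, 0.0])
idx = raw.long()

out_module = emb(idx)                                # nn.Embedding forward
out_rows = emb.weight[idx]                           # plain row lookup into the weights
out_select = torch.index_select(emb.weight, 0, idx)  # same lookup as a select along dim 0

print(torch.equal(out_module, out_rows), torch.equal(out_module, out_select))  # True True
```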
I am using PyTorch version 0.5.0a0+21e0fc8 and ONNX version 1.1.1.