[ONNX] Representation of Embedding operation

I am trying to convert the following PyTorch model I found on GitHub to the ONNX format:

I am working with the first model from the demo3 notebook.

I had some trouble creating the dummy input for it, so I’ve written the following code:

```python
import torch
import torch.onnx

# To get the right range and order of the embedding inputs:
# the first 5 columns of train_dataset.deep are the categorical (embedding) inputs,
# so record the largest value seen in each of them
maxes = []
for i in range(5):  # same as range(len(train_dataset.deep[0:5])) when there are >= 5 rows
    a = [x[i] for x in train_dataset.deep]
    maxes.append(int(max(a)))

# Creating the dummy input (since PyTorch 0.4, tensors no longer need a Variable wrapper)
dummy_wide_input = torch.randn(1, 798)
# One random in-range index per embedding column; .float() keeps every piece the
# same dtype so torch.cat works, then the 2 continuous columns are appended
ids = [torch.randint(0, m, (1,)).float() for m in maxes]
dummy_deep_input = torch.cat(ids + [torch.randn(2)]).unsqueeze(0)  # shape (1, 7)
dummy_input = (dummy_wide_input, dummy_deep_input)

# Exporting to ONNX
torch.onnx.export(model, dummy_input, 'wideanddeep.onnx')
```
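To sanity-check the dummy-input logic without the real dataset or model, here is a self-contained sketch in which a small hypothetical `deep` tensor stands in for `train_dataset.deep` (5 categorical columns followed by 2 continuous ones, as the code above assumes). `make_dummy_deep` is a hypothetical helper written for illustration, not part of the GitHub project.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for train_dataset.deep: 5 categorical columns
# (integer ids stored as floats) followed by 2 continuous columns.
deep = torch.tensor([
    [0., 2., 1., 4., 3., 0.7, -0.1],
    [1., 0., 2., 1., 6., 1.3, 0.4],
    [2., 1., 0., 3., 5., -0.2, 2.0],
])
N_EMBED_COLS = 5  # assumption: matches the 5 embedding inputs in the question


def make_dummy_deep(deep, n_embed_cols, n_cont_cols):
    """Hypothetical helper: build a (1, n_embed_cols + n_cont_cols) dummy row."""
    # Largest id seen in each categorical column
    maxes = [int(deep[:, i].max().item()) for i in range(n_embed_cols)]
    # One random id per column drawn from [0, max); .float() gives every piece one dtype
    ids = [torch.randint(0, m, (1,)).float() for m in maxes]
    cont = torch.randn(n_cont_cols)
    return torch.cat(ids + [cont]).unsqueeze(0), maxes


dummy_deep, maxes = make_dummy_deep(deep, N_EMBED_COLS, 2)
print(dummy_deep.shape)  # torch.Size([1, 7])
```

Building the list of pieces first and calling `torch.cat` once avoids starting from an empty tensor and growing it inside the loop.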

First I ran into a problem with the missing ATen `index` operation, so I tried to implement it myself in torch/onnx/symbolic.py, but I am far from confident that this implementation is right:

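```python
# Attempted symbolic for the missing ATen `index` op (added to torch/onnx/symbolic.py).
# It emits an ATen fallback node, but only args[0] (the tensor being indexed)
# is passed on as an input; the index arguments are dropped.
def index(g, *args, **kwargs):
    return g.op("ATen", args[0], operator_s="index")
```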

After this I finally managed to get an ONNX representation, but I really don’t understand how the Embedding operation is represented in it. When I view the network in the Netron visualizer, the Embedding looks like this:

First of all, I don’t understand why the Cast operation casts to 7 rather than to an actual type. When I run `onnx.checker.check_model` on this model, I even get an error from it:

ValidationError: Attribute 'to' is expected to have field 's'
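For reference, in ONNX the `to` attribute of Cast (from opset 6 on) is an integer code from the `TensorProto.DataType` enum defined in onnx.proto, not a type name, and 7 is that enum's code for INT64. The older Cast-1 definition took a string instead, so a checker complaining about field `s` may point to a mismatch between the opset the exporter targeted and the one the checker validated against (an assumption, not verified for these exact versions). The sketch below hard-codes the enum's first values so it runs without the onnx package; with onnx installed, `onnx.TensorProto.DataType.Name(7)` should return the same name.

```python
# Integer codes of ONNX's TensorProto.DataType enum (values from onnx.proto).
# Hard-coded here only so the sketch runs without the onnx package installed.
TENSOR_PROTO_DATA_TYPES = {
    0: "UNDEFINED",
    1: "FLOAT",
    2: "UINT8",
    3: "INT8",
    4: "UINT16",
    5: "INT16",
    6: "INT32",
    7: "INT64",
    8: "STRING",
    9: "BOOL",
    10: "FLOAT16",
    11: "DOUBLE",
    12: "UINT32",
    13: "UINT64",
    14: "COMPLEX64",
    15: "COMPLEX128",
}


def cast_target_name(to_attr):
    """Return the type name for the integer `to` attribute of an ONNX Cast node."""
    return TENSOR_PROTO_DATA_TYPES.get(to_attr, "unknown ({})".format(to_attr))


print(cast_target_name(7))  # INT64
```

Read this way, `Cast(to=7)` converts to int64, the integer type an embedding lookup needs for its indices. Presumably the model converts its float deep-input columns with something like `.long()` before the embedding tables (an assumption about the GitHub model's code).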

Besides that, I don’t see how this series of operations is the embedding at all. Does that make any sense?
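One way to see how a cast plus a lookup can be the whole embedding: `nn.Embedding` simply selects rows of its weight matrix by integer index, which is what ONNX's `Gather` does on its first input along axis 0. The minimal sketch below uses a hypothetical 10-by-4 embedding and a two-row deep input whose first column holds category ids stored as floats. It assumes (not verified against the exporter version in the question) that the exported graph typically corresponds to a cast to int64 followed by a Gather on the weight initializer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: 10 categories, 4-dimensional embeddings
emb = nn.Embedding(num_embeddings=10, embedding_dim=4)

# Column 0: category id stored as a float; column 1: a continuous feature
deep_rows = torch.tensor([[3.0, 0.5], [7.0, -1.2]])

# Step 1: cast the id column to int64 (in ONNX terms, a Cast with to=7)
idx = deep_rows[:, 0].long()

# Step 2: the lookup itself is just row selection from the weight matrix
via_module = emb(idx)
via_gather = emb.weight.index_select(0, idx)  # same as ONNX Gather(weight, idx, axis=0)

print(torch.equal(via_module, via_gather))  # True
print(via_module.shape)  # torch.Size([2, 4])
```

Because the lookup involves no arithmetic, both paths return bit-identical rows, which is why an exporter can replace the module with a plain gather.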

I am using PyTorch version 0.5.0a0+21e0fc8 and ONNX version 1.1.1.