What is the purpose of the 10 in the official example?

import torch
import torch.nn as nn

emb = nn.Embedding(10, 3)
print(emb.weight.shape)        # torch.Size([10, 3])
print(emb.weight)
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
print('input:', input.shape)   # torch.Size([2, 4])
embo = emb(input)
print(embo.shape)              # torch.Size([2, 4, 3])
print('embo:', embo)

Here is demo code that works, but I don't understand what the 10 means.

Since our input is torch.Size([2, 4]), the output is torch.Size([2, 4, 3]), because the embedding dimension emb.weight.shape[1] is 3.
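
As a quick sketch of this rule (the index tensors below are arbitrary examples, not from the official docs), the output shape is always the input shape with emb.weight.shape[1] appended:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 3)

# Output shape is always input shape + (embedding_dim,), here (..., 3)
for idx in (torch.LongTensor([1, 2, 4]),                # shape (3,)
            torch.LongTensor([[1, 2, 4, 5],
                              [4, 3, 2, 9]])):          # shape (2, 4)
    print(tuple(idx.shape), '->', tuple(emb(idx).shape))
# (3,) -> (3, 3)
# (2, 4) -> (2, 4, 3)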

Can someone explain to me in simple words what the 10 (i.e. emb.weight.shape[0]) means?

My current assumption is that one additional hidden layer of size 10x10 will be created, but I am not sure.

Hi,

You can check the doc.
But this is the number of entries in the embedding table, and so it determines the largest index that your input may contain.
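
A minimal sketch of that point (the index tensor is just an example): nn.Embedding(10, 3) stores a 10 x 3 weight table, and the forward pass simply looks up rows of that table by index, so every index must be smaller than 10:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 3)          # 10 rows (valid indices 0..9), each of size 3
idx = torch.LongTensor([[1, 2, 4, 5],
                        [4, 3, 2, 9]])

# The forward pass is just a row lookup into the weight table
print(torch.equal(emb(idx), emb.weight[idx]))   # True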

I cannot really see "the maximum index that can be contained in my input". For example:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 3)
print(emb.weight.shape)        # torch.Size([10, 3])
print(emb.weight)
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9], [2, 3, 4, 5]])
print('input:', input.shape)   # torch.Size([3, 4])
embo = emb(input)
print(embo.shape)              # torch.Size([3, 4, 3])
print('embo:', embo)

If I provide an input of shape torch.Size([3, 4]), the embedding output has shape torch.Size([3, 4, 3]). So there is no reflection of the 10 in the output.

I copied the example from the official documentation you pointed to, @albanD, so we are on the same page.

Hi,

Not the maximum number of elements. The maximum value in there :slight_smile:
As you can see, every value in there is in [0, 9]. If you try to set one value to 42, you will get an error (any value >= 10 will, since indices are 0-based).
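
For example, a small sketch of that failure (the 42 is just an arbitrary out-of-range value): indices 0..9 work, anything >= 10 raises an error:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 3)          # valid indices are 0..9

ok = torch.LongTensor([[0, 5, 9]])
print(emb(ok).shape)               # torch.Size([1, 3, 3])

bad = torch.LongTensor([[0, 5, 42]])
try:
    emb(bad)                       # 42 >= 10, so the row lookup fails
except IndexError as err:
    print('out-of-range index:', err)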
