>>> x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])
>>> x
tensor([[1, 2, 3, 4, 5],
        [6, 7, 8, 9, 0]])
>>> # `embedding` maps each index (0-9 here) to a 3-dim vector, e.g. torch.nn.Embedding(10, 3)
>>> embedded = embedding(x)
>>> embedded
tensor([[[ 0.2623, -0.6216,  1.8112],
         [ 0.5538, -1.0141, -1.6522],
         [-0.3584, -0.4176,  0.1584],
         [ 1.3007, -1.4878,  0.7318],
         [ 1.0062,  0.8589, -1.6106]],

        [[-0.4756,  1.3294,  0.3395],
         [ 0.4187,  0.9915, -0.5702],
         [ 0.3695,  0.8041, -0.6507],
         [-0.0904,  0.4995, -0.6367],
         [ 0.4574, -1.1187, -0.2644]]], grad_fn=<EmbeddingBackward>)