Differentiable argmax

If you use indexing or an embedding lookup, the index itself doesn't get a gradient, so I could probably learn a trick if you share how you use the index when it does work.
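Here is a small sketch of that point, assuming PyTorch (`logits` and `emb` are hypothetical stand-ins, not anything from the thread). `argmax` returns integer indices with no `grad_fn`, so an `nn.Embedding` lookup on them sends gradients to the embedding weight but nothing back to the scores that picked the index.

```python
import torch

torch.manual_seed(0)

# Hypothetical scores over a vocabulary of 10 tokens, batch of 4.
logits = torch.randn(4, 10, requires_grad=True)
# Hypothetical embedding layer: 10 tokens, 3-dimensional vectors.
emb = torch.nn.Embedding(10, 3)

# argmax produces an int64 tensor; the graph back to `logits` is cut here.
idx = logits.argmax(dim=-1)
# The lookup is differentiable w.r.t. emb.weight, but not w.r.t. idx.
out = emb(idx)
out.sum().backward()

print(idx.requires_grad)            # False
print(logits.grad)                  # None: nothing flowed back through argmax
print(emb.weight.grad is not None)  # True
```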

By “it doesn’t do gradients on the weight” I just meant that my function above won’t return gradients for the embedding_matrix (i.e. the weight of the embedding layer). Sorry for the confusion!
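For comparison, one common trick that does give gradients to both the scores and the embedding matrix is the straight-through estimator. The forward pass uses a hard one-hot vector, while the backward pass uses the softmax gradient, and the lookup becomes a matrix product. This is a minimal sketch, not the function discussed earlier. It assumes `logits` of shape (batch, vocab) and a plain tensor `embedding_matrix` of shape (vocab, dim); both names and all shapes are illustrative.

```python
import torch
import torch.nn.functional as F


def straight_through_argmax(logits):
    # Softmax supplies the gradient for the backward pass.
    soft = F.softmax(logits, dim=-1)
    # Hard one-hot at the argmax position, used for the forward values.
    idx = logits.argmax(dim=-1)
    hard = F.one_hot(idx, num_classes=logits.size(-1)).to(soft.dtype)
    # Forward value equals `hard`; gradient is that of `soft`.
    return hard - soft.detach() + soft


torch.manual_seed(0)
logits = torch.randn(4, 10, requires_grad=True)
embedding_matrix = torch.randn(10, 3, requires_grad=True)

one_hot = straight_through_argmax(logits)
# Same values as embedding_matrix[logits.argmax(-1)], but differentiable.
out = one_hot @ embedding_matrix
out.sum().backward()

print(logits.grad is not None, embedding_matrix.grad is not None)  # True True
```

The gradient with respect to `logits` is a biased surrogate, not the true derivative of argmax (which is zero almost everywhere). PyTorch's `F.gumbel_softmax(..., hard=True)` applies the same straight-through idea to a sampled, rather than deterministic, one-hot vector.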

Best regards

Thomas