Embedding after Gumbel-Softmax


I am training a model where one step during training is to sample some sequences, and I need to be able to backpropagate through this step.
To do so, I sample with F.gumbel_softmax(logits, tau=1, hard=True, dim=2).
My problem is that I need to evaluate a score on these sampled sequences, and for that I have to feed them back into the model, which starts with an Embedding layer. That layer expects indices, not one-hot vectors, and taking an argmax to recover the indices of the 1s breaks backpropagation.
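A minimal sketch of the problem, assuming logits of shape (batch, seq_len, vocab) so that dim=2 is the vocabulary axis (the sizes are made up for illustration). With hard=True the sample is one-hot in the forward pass but still carries gradient, whereas argmax returns integer indices that cannot carry gradient:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, for illustration only
batch, seq_len, vocab = 2, 5, 10
logits = torch.randn(batch, seq_len, vocab, requires_grad=True)

# hard=True: one-hot vectors in the forward pass, while gradients
# flow through the soft sample (straight-through estimator)
one_hot = F.gumbel_softmax(logits, tau=1, hard=True, dim=2)

# argmax yields integer indices, which are cut off from the graph
indices = one_hot.argmax(dim=2)

print(one_hot.requires_grad)  # True
print(indices.requires_grad)  # False
print(indices.dtype)          # torch.int64
```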

It is clearly computable, but I don't really know how to do it nicely… I was thinking of taking the weights of the nn.Embedding and doing some matrix multiplication…
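The matrix-multiplication idea works because multiplying a one-hot vector by the embedding weight matrix selects exactly one row of it. A small sketch (sizes are arbitrary) checking that the product matches the index-based lookup while keeping a path back to the logits:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

vocab, emb_dim = 10, 4
embedding = nn.Embedding(vocab, emb_dim)

logits = torch.randn(2, 5, vocab, requires_grad=True)
one_hot = F.gumbel_softmax(logits, tau=1, hard=True, dim=2)  # (2, 5, vocab)

# (batch, seq_len, vocab) @ (vocab, emb_dim) -> (batch, seq_len, emb_dim)
# Each one-hot row picks out a single row of the weight matrix.
emb_soft = one_hot @ embedding.weight

# Reference: the usual index lookup (no gradient path to the logits)
emb_idx = embedding(one_hot.argmax(dim=2))

print(torch.allclose(emb_soft, emb_idx))  # True

# Gradients now reach the logits via the straight-through estimator
emb_soft.sum().backward()
print(logits.grad is not None)  # True
```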

I would be thankful for any advice.


If your embedding is standard, you can multiply the one-hot samples by the embedding weight matrix with torch.matmul, as done here: https://github.com/facebookresearch/EGG/blob/170e5fe63c13244121a5b29b9bfb4870a0f11796/egg/core/gs_wrappers.py#L203
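A sketch of a drop-in embedding in the same spirit as the linked code (not a copy of it): it does a regular lookup for integer indices and a matmul with the weight for float one-hot inputs. The class name OneHotEmbedding is hypothetical, and it assumes none of nn.Embedding's extra options (padding_idx, max_norm, scale_grad_by_freq, sparse) are used:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OneHotEmbedding(nn.Embedding):
    """Hypothetical nn.Embedding subclass that also accepts one-hot
    (or relaxed) float inputs of shape (..., num_embeddings).

    Assumes a "standard" embedding: no padding_idx, max_norm,
    scale_grad_by_freq or sparse options.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dtype == torch.long:
            # Integer indices: regular lookup
            return super().forward(x)
        # Float one-hot / relaxed input: weighted sum of embedding rows
        return torch.matmul(x, self.weight)


torch.manual_seed(0)
emb = OneHotEmbedding(10, 4)
logits = torch.randn(3, 7, 10, requires_grad=True)
sample = F.gumbel_softmax(logits, tau=1, hard=True, dim=-1)

out_soft = emb(sample)                 # differentiable w.r.t. logits
out_idx = emb(sample.argmax(dim=-1))   # usual index path
print(out_soft.shape)                  # torch.Size([3, 7, 4])
print(torch.allclose(out_soft, out_idx))  # True
```

Keeping both paths in one module means the rest of the model can stay unchanged: real data goes in as indices, sampled sequences go in as one-hot tensors.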

Otherwise, you also have to reimplement the Embedding's extra behavior yourself. E.g. if you set padding_idx, I'm not sure, but I think you have to figure out how to keep that row from contributing to the gradient.
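One possible way to handle padding_idx, sketched under the assumption that the goal is to mimic nn.Embedding, where the padding row receives no gradient: detach that row of the weight before the matmul. The helper name one_hot_embed is hypothetical:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def one_hot_embed(one_hot: torch.Tensor, embedding: nn.Embedding) -> torch.Tensor:
    """Hypothetical helper: one-hot lookup that, like nn.Embedding with
    padding_idx, keeps the padding row from receiving gradient."""
    weight = embedding.weight
    if embedding.padding_idx is not None:
        # mask is 1 for every row except the padding row
        mask = torch.ones(weight.size(0), 1, device=weight.device, dtype=weight.dtype)
        mask[embedding.padding_idx] = 0
        # Same forward values, but the padding row comes from a detached copy
        weight = weight * mask + weight.detach() * (1 - mask)
    return torch.matmul(one_hot, weight)


torch.manual_seed(0)
emb = nn.Embedding(10, 4, padding_idx=3)
idx = torch.tensor([[3, 1, 3, 5]])
one_hot = F.one_hot(idx, num_classes=10).float()

out = one_hot_embed(one_hot, emb)
print(torch.allclose(out, emb(idx)))  # True

out.sum().backward()
print(emb.weight.grad[3].abs().sum().item())  # 0.0
```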