I am working on a model where, during training, one of the steps is to sample some sequences, and I need to be able to backpropagate through this step.
To do so I am sampling using
F.gumbel_softmax(logits, tau=1, hard=True, dim=2)
My problem is that I need to evaluate a score on these sampled sequences, and to do so I need to feed them back into the model, which starts with an Embedding layer (which expects its input as indices, not one-hot vectors). Taking an argmax to get the indices of the
1s would break the backprop.
Clearly it is computable, but I don't really know how to do that nicely… I was thinking of taking the weights of the
nn.Embedding and doing a matrix multiplication with the one-hot samples…
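A rough sketch of what I have in mind is below. The sizes and variable names are made up for illustration, and it assumes a plain nn.Embedding with no padding_idx or max_norm; the argmax is there only to compare against the usual lookup.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch, seq_len, vocab, emb_dim = 2, 5, 10, 4

embedding = nn.Embedding(vocab, emb_dim)
logits = torch.randn(batch, seq_len, vocab, requires_grad=True)

# Straight-through one-hot samples, shape (batch, seq_len, vocab).
one_hot = F.gumbel_softmax(logits, tau=1, hard=True, dim=2)

# Multiplying a one-hot row by the weight matrix picks out the same row an
# index lookup would return, but keeps the graph connected to `logits`.
embedded = one_hot @ embedding.weight  # (batch, seq_len, emb_dim)

# Sanity check against the usual index lookup (argmax used for comparison only).
indices = one_hot.argmax(dim=2)
same_values = torch.allclose(embedded, embedding(indices))

# Gradients should now reach both the logits and the embedding weights.
embedded.sum().backward()
has_logit_grad = logits.grad is not None
has_weight_grad = embedding.weight.grad is not None
```

If this is right, the rest of the model would take `embedded` directly instead of calling the Embedding layer on indices.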
I would be thankful for any advice.