Adversarial text generation using discrete sampling

I have a seq2seq model that I want to train in an adversarial setting. At each timestep, I want to do the following (a sketch of the naive, non-differentiable version follows the list):

  • Obtain a probability distribution over the vocabulary from the decoder logits.
  • Sample a word from that distribution.
  • Feed the sampled word back to the decoder as the next input (scheduled sampling).
  • Feed the sequence of words produced by the decoder to a discriminator.
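
For reference, a naive (non-differentiable) version of steps 1–2 would look like the sketch below; all sizes are made up:

import torch

logits = torch.randn(32, 10000)  # decoder logits for one timestep (batch, vocab)
probs = torch.softmax(logits, dim=-1)                        # step 1
word_idx = torch.distributions.Categorical(probs).sample()   # step 2: gradients stop here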

I need my model to be differentiable end-to-end. In order to “soft-sample” words, I use as the word representation a weighted sum of all the word embeddings, with the weights given by the softmax of the decoder logits at each timestep (as in Goyal et al., 2017).
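
Concretely, that soft word representation is something like this sketch (the embedding layer and sizes are stand-ins):

import torch
import torch.nn as nn
import torch.nn.functional as F

embedding = nn.Embedding(10000, 300)   # stand-in for the model's embedding layer
logits = torch.randn(32, 10000)        # decoder logits for one timestep
tau = 1.0

probs = F.softmax(logits / tau, dim=-1)    # (batch, vocab)
soft_emb = probs.mm(embedding.weight)      # convex combination of all word embeddings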

Ideally, I want to sample discrete words but backpropagate as if I had sampled from the softmax of the logits, so that the whole model remains differentiable. I picked up this trick from the PyTorch source: https://github.com/pytorch/pytorch/blob/425ea6b31e433eaf1e4aa874332c1d6176cc6a8a/torch/nn/functional.py#L1025
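
For what it's worth, recent PyTorch versions expose this directly as F.gumbel_softmax; with hard=True the forward pass is one-hot while gradients flow through the soft distribution (note that it also adds Gumbel noise, i.e. it actually samples rather than taking the argmax):

import torch
import torch.nn.functional as F

logits = torch.randn(32, 10000)
y = F.gumbel_softmax(logits, tau=1.0, hard=True)  # one-hot forward, soft backward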

This is my implementation of what I described:

import torch
import torch.nn.functional as F

def softmax_discrete(self, logits, tau=1.0):
    # logits: (batch, vocab_size)
    y_soft = F.softmax(logits / tau, dim=-1)
    # One-hot vector at the argmax of the softened distribution
    k = y_soft.max(-1)[1]
    y_hard = torch.zeros_like(y_soft).scatter_(-1, k.unsqueeze(-1), 1.0)
    # Straight-through: the forward pass uses y_hard,
    # the backward pass flows through y_soft
    y = y_hard - y_soft.detach() + y_soft
    return y

def _current_emb(self, logits, tau):
    # logits: (batch, 1, vocab_size) -> drop the time dimension once, here
    dist = self.softmax_discrete(logits.squeeze(1), tau)
    # A one-hot dist selects a single word's embedding in the forward pass,
    # but gradients are taken w.r.t. the soft distribution
    e_i = dist.mm(self.embedding.weight)
    return e_i.unsqueeze(1)
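
In the decoding loop, the two methods are wired together roughly like this (the decoder call and names are schematic, not runnable as-is):

# emb_t: (batch, 1, emb_dim), hidden: decoder state -- names are schematic
logits, hidden = self.decoder(emb_t, hidden)   # logits: (batch, 1, vocab)
emb_t = self._current_emb(logits, tau)         # sampled word fed back (step 3)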

Is this implementation correct? Since I am computing a weighted sum of the word embeddings using a discrete (one-hot) distribution, in which most embeddings are multiplied by zero, will the gradients be nonzero (as if the distribution were soft), or not?
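
Here is a minimal standalone version of the trick, which can be used to inspect the gradients (toy shapes, no model):

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, requires_grad=True)
y_soft = F.softmax(logits, dim=-1)
k = y_soft.max(-1)[1]
y_hard = torch.zeros_like(y_soft).scatter_(-1, k.unsqueeze(-1), 1.0)
y = y_hard - y_soft.detach() + y_soft       # straight-through
emb = torch.randn(5, 3)
y.mm(emb).sum().backward()
print(logits.grad)                          # dense and (in general) nonzero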

My model is not performing very well and I would like to rule out this part as one of the reasons.

Thanks!