Sample from GumbelSoftmax in a way that tracks gradients

I’m currently using Pyro to build a VAE for seq2seq name generation, and I have this bit of code:

sampled_indexes = pyro.sample(f"{address}_{index}", pyro.distributions.RelaxedOneHotCategoricalStraightThrough(torch.tensor(1.0), logits=char_dist), obs=observed[index])

Essentially, char_dist holds the logits (produced by a neural network) of a categorical distribution over characters for the i-th character of a name (index in the code above). pyro.sample expects a distribution to draw a value from, and here RelaxedOneHotCategoricalStraightThrough returns a vector over the whole character vocabulary, so it effectively samples a categorical distribution over characters. I’d rather use a Gumbel-softmax distribution in PyTorch that samples just the value whose softmax is 1 (an exact one-hot choice of character), because I think that is easier for Pyro to track than a sampled distribution over characters.

My concern is that if the Gumbel-softmax does not return an exact value, the gradients won’t be tracked during backprop. Any suggestions?
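For reference, here is a minimal sketch of the straight-through Gumbel-softmax trick in plain PyTorch. As far as I understand the Pyro docs, this is the idea behind RelaxedOneHotCategoricalStraightThrough: the forward value is an exact one-hot vector, while the backward pass uses the gradient of the relaxed (soft) sample. `logits` stands in for `char_dist`; `vocab_size`, `weights` and the toy loss are hypothetical and only there to show that gradients reach the logits. PyTorch's built-in `torch.nn.functional.gumbel_softmax(..., hard=True)` applies the same trick.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size = 5  # hypothetical number of characters
# Hypothetical stand-in for the network output `char_dist` (logits over characters).
logits = torch.randn(vocab_size, requires_grad=True)

def straight_through_gumbel_softmax(logits, tau=1.0):
    # Gumbel(0, 1) noise: -log(E) with E ~ Exponential(1).
    gumbels = -torch.empty_like(logits).exponential_().log()
    # Relaxed (soft) sample: differentiable with respect to the logits.
    y_soft = F.softmax((logits + gumbels) / tau, dim=-1)
    # Hard one-hot vector at the argmax of the soft sample.
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    # Straight-through: forward value equals y_hard (up to float rounding),
    # backward pass uses the gradient of y_soft.
    return y_hard - y_soft.detach() + y_soft

sample = straight_through_gumbel_softmax(logits)
char_index = sample.argmax().item()  # integer index of the sampled character

# Toy downstream loss that depends on the one-hot sample.
weights = torch.arange(vocab_size, dtype=torch.float)
loss = (sample * weights).sum()
loss.backward()

# PyTorch's built-in version of the same estimator.
builtin = F.gumbel_softmax(logits.detach(), tau=1.0, hard=True)

print(sample, char_index, logits.grad, builtin)
```

The sampled vector is one-hot in the forward pass, yet `logits.grad` is populated after `backward()`, so an exact one-hot value and gradient flow are not mutually exclusive.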