Dirichlet distribution sometimes outputs tensor full of NaNs

Hey guys,
I found an issue with the Dirichlet distribution. Sometimes, for no apparent reason, it outputs a tensor full of NaNs. The line causing the problem is as simple as:

doc_topics = Dirichlet(topic_weights).rsample()

I always assert that topic_weights is a valid tensor before this line.
For now, I'm working around the problem by replacing the line above with the following code:

repeat = True
while repeat:
    doc_topics = Dirichlet(topic_weights).rsample()
    repeat = doc_topics.isnan().any()

It works… but it is as ugly as code can possibly be. (In Brazil we have a slang word for that: “gambiarra” :slight_smile:)

Any hints on how to solve that?
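Until the root cause is found, the retry workaround can at least be bounded, so that a persistent problem raises an error instead of looping forever. This is only a sketch of the same idea, not a fix: `rsample_without_nan` and `max_tries` are hypothetical names, and it assumes NaN draws are rare enough that a retry usually succeeds. Keep in mind that rejecting NaN draws means you are sampling from the Dirichlet conditioned on "no NaNs", which only matches the intended distribution if the failures are rare and unrelated to the values.

```python
import torch
from torch.distributions import Dirichlet


def rsample_without_nan(concentration: torch.Tensor, max_tries: int = 10) -> torch.Tensor:
    """Draw a reparameterized Dirichlet sample, retrying if it contains NaNs.

    Hypothetical helper: same idea as the while-loop workaround, but with a
    retry limit so a persistent problem surfaces as an error.
    """
    dist = Dirichlet(concentration)  # build the distribution once, sample up to max_tries times
    for _ in range(max_tries):
        sample = dist.rsample()
        if not torch.isnan(sample).any():
            return sample
    raise RuntimeError(
        f"Dirichlet.rsample() returned NaNs {max_tries} times in a row; "
        "consider saving the concentration and RNG state for a bug report"
    )


# Usage with placeholder weights (assumed values, for illustration only)
topic_weights = torch.full((5,), 0.5)
doc_topics = rsample_without_nan(topic_weights)
```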

Thank you for the new word! I’ll be sure to use it next time I need it (and around computers, that will be soon…).

The obvious solution is to fix PyTorch so that it doesn't do that. :slight_smile: But can you tell us more about topic_weights? Is it on CUDA, what shape is it, and is there anything unusual about the values? The perfect reproducing example would let us generate the NaNs ourselves, by providing the weights together with a random seed or RNG state that hits the problem. A bug report with a reproducing case would be a great help! (Or just paste it here and we can forward it.)
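One way to capture such a case is to record the RNG state right before each draw and save it together with the weights as soon as a NaN appears. The sketch below assumes CPU tensors: `torch.get_rng_state()` only covers the CPU generator, so for CUDA tensors you would need `torch.cuda.get_rng_state()` instead. The helper names and the file path are hypothetical.

```python
import torch
from torch.distributions import Dirichlet


def find_nan_repro(topic_weights: torch.Tensor, n_trials: int = 10_000,
                   path: str = "dirichlet_nan_repro.pt"):
    """Sample repeatedly; on the first NaN, save what is needed to replay it.

    Hypothetical helper, CPU only as written. Returns the path if a NaN was
    captured, otherwise None.
    """
    dist = Dirichlet(topic_weights)
    for trial in range(n_trials):
        rng_state = torch.get_rng_state()  # CPU generator state *before* this draw
        sample = dist.rsample()
        if torch.isnan(sample).any():
            torch.save({"topic_weights": topic_weights.detach().cpu(),
                        "rng_state": rng_state,
                        "trial": trial}, path)
            return path
    return None  # no NaNs seen in n_trials draws


def replay_repro(path: str) -> torch.Tensor:
    """Restore the saved RNG state and redraw the sample that produced NaNs."""
    repro = torch.load(path)
    torch.set_rng_state(repro["rng_state"])
    return Dirichlet(repro["topic_weights"]).rsample()
```

The saved file (or its contents) is the kind of thing that could be attached to a bug report.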
There have been issues with very large or very small parameters for some distributions, but a quick search doesn't suggest that this Dirichlet issue is already known.
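One hypothesis worth testing (an assumption, not something established here) is that very small concentration values are involved: Dirichlet samples are commonly obtained by normalizing independent Gamma draws, and if every component's Gamma draw underflows to zero, the normalization would give 0/0 = NaN. The probe below just measures the fraction of NaN-containing draws for a constant concentration vector at several scales. `nan_rate_by_scale`, the chosen scales and the sizes are all hypothetical; it is worth running it with the device and dtype your model actually uses.

```python
import torch
from torch.distributions import Dirichlet


def nan_rate_by_scale(dim: int = 50,
                      scales=(1e-4, 1e-3, 1e-2, 1e-1, 1.0, 10.0),
                      n_samples: int = 1000,
                      dtype=torch.float32,
                      device: str = "cpu"):
    """Fraction of Dirichlet draws containing any NaN, per concentration scale.

    Hypothetical probe: every component of the concentration equals `scale`.
    """
    torch.manual_seed(0)  # fixed seed so the probe itself is reproducible
    rates = {}
    for scale in scales:
        conc = torch.full((dim,), scale, dtype=dtype, device=device)
        samples = Dirichlet(conc).rsample((n_samples,))  # shape (n_samples, dim)
        rates[scale] = torch.isnan(samples).any(dim=-1).float().mean().item()
    return rates


for scale, rate in nan_rate_by_scale(dim=10, n_samples=200).items():
    print(f"concentration={scale:g}: NaN rate {rate:.3f}")
```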

Best regards


Thanks @tom! I’m trying to create a reproducing example that runs in Colab… will share soon… Thanks!