Dirichlet/Gamma Distribution Sampling And Reparam. Gradients

I’m attempting to find the methods used to sample from, and compute the reparameterized gradients of, the Dirichlet and Gamma distributions.

The implementations of both the Dirichlet and Gamma distributions call seemingly hidden functions.

For Dirichlet: torch._sample_dirichlet and torch._dirichlet_grad

For Gamma: torch._standard_gamma

Is there documentation on the sampling/gradient methods and these functions?
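For anyone else landing here: even though those internals are undocumented, they are reachable through the public torch.distributions API, whose rsample() is differentiable for both distributions. A minimal sketch (calling the underscore functions directly is an assumption based on the ATen declarations, not a stable API):

```python
import torch

# Public API: Dirichlet.rsample() routes through torch._sample_dirichlet,
# Gamma.rsample() through torch._standard_gamma, and both support
# reparameterized (pathwise) gradients w.r.t. their parameters.
conc = torch.tensor([0.5, 1.0, 2.0], requires_grad=True)
x = torch.distributions.Dirichlet(conc).rsample()  # point on the simplex
(x ** 2).sum().backward()                          # gradient flows to conc
print(x.sum().item())        # ~1.0 (simplex constraint)
print(conc.grad is not None)

# The undocumented kernels can also be called directly (unstable API):
y = torch._sample_dirichlet(torch.tensor([0.5, 1.0, 2.0]))
g = torch._standard_gamma(torch.tensor([2.0]))
```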


Bumping this. @ptrblck, @smth do you folks know who the correct person to ask about this would be?

Hey Max.

Looks like these functions were never properly documented, but the CPU implementations for these are pretty readable.

This is where _sample_dirichlet dispatches to: pytorch/aten/src/ATen/native/Distributions.cpp at a6ac6447b55bcf910dee5f925c2c17673f162a36 · pytorch/pytorch · GitHub. I figured this out by searching for _sample_dirichlet in the native_functions.yaml file.

This is where _standard_gamma dispatches to: pytorch/aten/src/ATen/native/Distributions.cpp at a6ac6447b55bcf910dee5f925c2c17673f162a36 · pytorch/pytorch · GitHub

For both of these, the core gamma sampling function that they call is implemented here (lifted straight from the NumPy sources): pytorch/aten/src/ATen/native/Distributions.h at a6ac6447b55bcf910dee5f925c2c17673f162a36 · pytorch/pytorch · GitHub
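For reference, that NumPy-derived sampler is the Marsaglia–Tsang squeeze-rejection method. A minimal Python transcription of the algorithm (my own sketch for illustration, not the actual ATen code) looks like:

```python
import math
import random

def standard_gamma(alpha: float) -> float:
    """Marsaglia-Tsang sampler for Gamma(alpha, 1). Sketch only."""
    if alpha < 1.0:
        # Boosting trick: Gamma(alpha) = Gamma(alpha + 1) * U^(1/alpha)
        return standard_gamma(alpha + 1.0) * random.random() ** (1.0 / alpha)
    d = alpha - 1.0 / 3.0
    c = 1.0 / math.sqrt(9.0 * d)
    while True:
        x = random.gauss(0.0, 1.0)
        v = (1.0 + c * x) ** 3
        if v <= 0.0:
            continue  # reject: cubed term must stay positive
        u = random.random()
        # Fast "squeeze" acceptance check before the expensive log test.
        if u < 1.0 - 0.0331 * x ** 4:
            return d * v
        # Full log-acceptance test.
        if math.log(u) < 0.5 * x * x + d * (1.0 - v + math.log(v)):
            return d * v
```

The rejection loop accepts the vast majority of proposals for moderate alpha, which is why this method is the standard choice in NumPy and elsewhere.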

Hope this helps.

Yes, this is very helpful! Thanks for diving into the belly of the beast for me… I would never have found these!

Can I assume the GPU implementations are in effect the ~same algorithms?

Yes, the GPU implementations use the same algorithms.