Find all registered KL divergence implementations

I would like to know which distributions I can use out of the box with the kl_divergence function. Is there a quick way to find this out? Thanks!

Check the torch.distributions.kl._KL_REGISTRY dictionary (its keys are (type_p, type_q) pairs of distribution classes), or look inside the kl.py source.
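If you only need to know whether one particular pair works, a minimal sketch is to call the public kl_divergence and treat NotImplementedError as "no implementation for this pair". The helper name kl_supported is hypothetical. Because it goes through kl_divergence itself, it also reports pairs that are covered through subclass dispatch, which a plain lookup of the registry keys would miss.

```python
import torch
from torch.distributions import Cauchy, Normal, kl_divergence


def kl_supported(p, q):
    """Hypothetical helper: True if kl_divergence(p, q) can be computed.

    It simply calls the public kl_divergence and treats NotImplementedError
    as "no KL implementation is available for this pair of types".
    """
    try:
        kl_divergence(p, q)
    except NotImplementedError:
        return False
    return True


p = Normal(torch.tensor(0.0), torch.tensor(1.0))
q = Normal(torch.tensor(1.0), torch.tensor(2.0))
r = Cauchy(torch.tensor(0.0), torch.tensor(1.0))

print(kl_supported(p, q))  # Normal || Normal is registered
print(kl_supported(p, r))  # Normal || Cauchy has no implementation
```

Note that this actually evaluates the divergence, so it is a quick interactive check rather than a free lookup.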


Thank you for the suggestion, @googlebot. Based on it, I wrote this function, which does what I want:

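from torch.distributions.kl import _KL_REGISTRY

def view_kl_options():
    """Print every (p, q) distribution pair with a registered KL implementation."""
    # _KL_REGISTRY maps (type_p, type_q) class pairs to their KL functions
    names = [(k[0].__name__, k[1].__name__) for k in _KL_REGISTRY.keys()]
    # Length of the longest first-argument name, so the "||" columns line up
    max_name_len = max(len(p_name) for p_name, _ in names)
    for arg1, arg2 in sorted(names):
        print(f"  {arg1:>{max_name_len}} || {arg2}")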

This prints output like the following (the exact contents of the registry depend on your PyTorch version):

                  Bernoulli || Bernoulli
                  Bernoulli || Poisson
                       Beta || Beta
                       Beta || ContinuousBernoulli
                       Beta || Exponential
                       Beta || Gamma
                       Beta || Normal
                       Beta || Pareto
                       Beta || Uniform
                   Binomial || Binomial
                Categorical || Categorical
                     Cauchy || Cauchy
        ContinuousBernoulli || ContinuousBernoulli
        ContinuousBernoulli || Exponential
        ContinuousBernoulli || Normal
        ContinuousBernoulli || Pareto
        ContinuousBernoulli || Uniform
                  Dirichlet || Dirichlet
                Exponential || Beta
                Exponential || ContinuousBernoulli
                Exponential || Exponential
                Exponential || Gamma
                Exponential || Gumbel
                Exponential || Normal
                Exponential || Pareto
                Exponential || Uniform
          ExponentialFamily || ExponentialFamily
                      Gamma || Beta
                      Gamma || ContinuousBernoulli
                      Gamma || Exponential
                      Gamma || Gamma
                      Gamma || Gumbel
                      Gamma || Normal
                      Gamma || Pareto
                      Gamma || Uniform
                  Geometric || Geometric
                     Gumbel || Beta
                     Gumbel || ContinuousBernoulli
                     Gumbel || Exponential
                     Gumbel || Gamma
                     Gumbel || Gumbel
                     Gumbel || Normal
                     Gumbel || Pareto
                     Gumbel || Uniform
                 HalfNormal || HalfNormal
                Independent || Independent
                    Laplace || Beta
                    Laplace || ContinuousBernoulli
                    Laplace || Exponential
                    Laplace || Gamma
                    Laplace || Laplace
                    Laplace || Normal
                    Laplace || Pareto
                    Laplace || Uniform
  LowRankMultivariateNormal || LowRankMultivariateNormal
  LowRankMultivariateNormal || MultivariateNormal
         MultivariateNormal || LowRankMultivariateNormal
         MultivariateNormal || MultivariateNormal
                     Normal || Beta
                     Normal || ContinuousBernoulli
                     Normal || Exponential
                     Normal || Gamma
                     Normal || Gumbel
                     Normal || Normal
                     Normal || Pareto
                     Normal || Uniform
          OneHotCategorical || OneHotCategorical
                     Pareto || Beta
                     Pareto || ContinuousBernoulli
                     Pareto || Exponential
                     Pareto || Gamma
                     Pareto || Normal
                     Pareto || Pareto
                     Pareto || Uniform
                    Poisson || Bernoulli
                    Poisson || Binomial
                    Poisson || Poisson
    TransformedDistribution || TransformedDistribution
                    Uniform || Beta
                    Uniform || ContinuousBernoulli
                    Uniform || Exponential
                    Uniform || Gamma
                    Uniform || Gumbel
                    Uniform || Normal
                    Uniform || Pareto
                    Uniform || Uniform
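One caveat when reading this list: it only shows the registry keys, but kl_divergence also dispatches on subclasses and picks the most specific registered pair. For example, Chi2 does not appear above, yet because Chi2 subclasses Gamma, two Chi2 distributions are handled by the Gamma || Gamma implementation. A small sketch, using the fact that Chi2(df) is the same distribution as Gamma(df / 2, 0.5):

```python
import torch
from torch.distributions import Chi2, Gamma, kl_divergence

# Chi2 is implemented as a subclass of Gamma
print(issubclass(Chi2, Gamma))

p = Chi2(torch.tensor(3.0))
q = Chi2(torch.tensor(5.0))
kl = kl_divergence(p, q)  # dispatched via the Gamma || Gamma entry

# Same divergence computed with the equivalent Gamma(concentration, rate) objects
ref = kl_divergence(
    Gamma(torch.tensor(1.5), torch.tensor(0.5)),
    Gamma(torch.tensor(2.5), torch.tensor(0.5)),
)
print(torch.allclose(kl, ref))
```

So the table is best read as "pairs with a dedicated implementation"; subclasses of the listed classes are covered as well.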