Implementing Truncated Normal Initializer

@uapatira - that’s why I asked this question in the first place :slight_smile:
@alexis-jacq - Thanks for the great answer!


@alexis-jacq This approximation is not close to the true distribution.

Well it seems ok with plots:

in blue, the normal distribution (sorted)
in red, the truncated distribution (also sorted)


Well, the method confines the random variables to a circle when you want a square [Edit: The square would not work either as it has the wrong marginals, too. One needs the full “other dimension” for each coordinate]. It probably works just fine for the initialisation heuristic, but it is far from general purpose.
(Interestingly, a similar transformation to polar coordinates is the easiest way I know to compute the normalisation constant in the normal transform. If you could apply it to squares you would get a closed form for the normal CDF.)

Best regards

Thomas

I don’t understand why you want a square… Don’t we want to confine to a circle of radius 2? And anyway, my example here is in 1d, so the circle and the square are the same: the segment [-2, 2]. The Box-Muller method for the truncated normal is already known and used, since this paper: http://oa.upm.es/22659/1/INVE_MEM_2012_153251.pdf

@alexis-jacq I haven’t looked at the paper. But in my simulation, the sampled points do not follow the true truncated normal distribution:

The red line is the pdf given by scipy.stats.truncnorm


This seems to be the picture I had in mind when I said that the method cuts off too much by restricting to the circle…

If you can do with an approximation, torch.fmod(torch.randn(size),2) might be easiest.
It moves more of the cut-off mass to the center than to the boundaries, but at least the good samples stay where they belong…
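For anyone who wants to try it, here is a minimal sketch of that one-liner wrapped in a function. The name fmod_truncated_normal and the cutoff parameter are made up for illustration; the only thing it guarantees is that every sample lands strictly inside (-cutoff, cutoff), not that the result is an exact truncated normal.

```python
import torch


def fmod_truncated_normal(size, cutoff=2.0):
    # Hypothetical helper: draw standard normal samples and fold them into
    # (-cutoff, cutoff). torch.fmod keeps the sign of each sample, so
    # out-of-range values are moved toward the center.
    return torch.fmod(torch.randn(size), cutoff)


samples = fmod_truncated_normal((1000, 100))
print(samples.shape, samples.abs().max())
```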


Wow, and your approximation takes one line!

So it looks like you take all the elements over 2 and put them close to the mean, like in a snake game, and that’s why you still have too many samples close to the mean. And it seems that the worst case comes with variance = 1. It would be worth writing down the mathematics, to see how far we are from a Gaussian with both methods.

Hi Alexis,

Well, I think one of the things is that your sketch seems different from the usual definition of truncated Gaussian.
So what the .fmod in sampling does in terms of pdf is to sum the pdf over all x+sign(x)*cutoff*i | i=0,1,2,… . If the shape of the pdf in [cutoff*i, cutoff*(i+1)] were proportional to that in [0,cutoff], this would be exact. As it is, the pdf falls off much more quickly than proportionally (a 3 std event conditional on the magnitude being at least 2 std is much less likely than a 1 std event). This is why the summation puts too much mass in the center. With a smaller cutoff, the pdf in [-cutoff,cutoff] becomes flatter (closer to uniform) and thus the error is more pronounced. It would increase further if you cut off at less than 1 std, but at some point you might just use a uniform distribution anyway. Even at 1 you could use a cutoff of 2/3 (or so) and then multiply the sample by 3/2 to approximate the truncated normal better (i.e. torch.fmod(torch.randn(size),2/3)*3/2).
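To put numbers on this, here is a small sketch (standard library only) that evaluates the folded density described above next to the density of the standard normal truncated to [-cutoff, cutoff]. The function names are made up for illustration, and the infinite sum is simply cut off after a fixed number of terms, which is an assumption that is harmless here because the terms decay extremely fast.

```python
import math


def phi(x):
    """Standard normal pdf."""
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)


def Phi(x):
    """Standard normal cdf."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def folded_pdf(x, cutoff=2.0, terms=50):
    """Density of fmod(Z, cutoff) at x for standard normal Z, with |x| < cutoff."""
    return sum(phi(abs(x) + cutoff * i) for i in range(terms))


def truncated_pdf(x, cutoff=2.0):
    """Density of the standard normal truncated to [-cutoff, cutoff]."""
    return phi(x) / (Phi(cutoff) - Phi(-cutoff))


for x in (0.0, 1.0, 1.9):
    print(x, folded_pdf(x), truncated_pdf(x))
```

With a cutoff of 2 the folded density sits above the truncated one at the center and below it near the boundary, which matches the "too much mass in the center" observation.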

In the plain Box-Muller method, you are essentially using two transforms: first, that for U_1 uniform on [0,1], R = sqrt(-2 ln(U_1)) has the distribution of the radii in a 2d standard normal, χ(2). Thus X,Y = sin/cos(2π U_2) * R, with U_2 uniform on [0,1], will be 2d standard normal. Then you are using that the marginals of the 2d standard normal are 1d standard normals.

If you now restrict R to R <= c and look at the pdf of X close to c, only very little of the total mass along the vertical (Y-coordinate) lies in the disc R<=c. This is why the “restricted BM” distribution has the density going to 0 as X approaches c.
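A rough way to see this effect numerically, using only the standard library: the sketch below draws samples with plain Box-Muller (taking R = sqrt(-2 ln U_1) and the angle 2π U_2), with the radius-restricted variant, and compares how much mass each puts near the boundary. The helper names are invented for this sketch, and restricting the radius by drawing U_1 from (exp(-c^2/2), 1] is my reading of "restrict R to R <= c", not code from the thread.

```python
import math
import random


def box_muller(rng, max_radius=None):
    """One Box-Muller sample; optionally restrict the radius to R <= max_radius."""
    # R = sqrt(-2 ln U_1), so R <= c is equivalent to U_1 >= exp(-c^2 / 2).
    low = 0.0 if max_radius is None else math.exp(-0.5 * max_radius ** 2)
    u1 = low + (1.0 - low) * (1.0 - rng.random())  # in (low, 1], avoids log(0)
    u2 = rng.random()
    r = math.sqrt(-2.0 * math.log(u1))
    return r * math.cos(2.0 * math.pi * u2)


def tail_fraction(samples, lo=1.5):
    """Fraction of samples with magnitude above lo."""
    return sum(abs(x) > lo for x in samples) / len(samples)


rng = random.Random(0)
n = 100_000
restricted = [box_muller(rng, max_radius=2.0) for _ in range(n)]
plain = [box_muller(rng) for _ in range(n)]
# Rejection gives samples from the true truncated normal on [-2, 2].
truncated = [x for x in plain if abs(x) <= 2.0]
print(tail_fraction(restricted), tail_fraction(truncated))
```

The restricted variant should show noticeably less mass in 1.5 < |x| <= 2 than the rejection-based samples, in line with the density going to 0 at the boundary.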

But then I would be surprised if the initialisation heuristic broke down with a simple approximation.

Best regards

Thomas


Ok thanks, now I understand better. So, close to x = 0, all samples should be accepted, since cos(U_2) is small, but with my approach I still multiply by a restricted R and make it even smaller. That explains the plot of @ruotianluo, where I overshoot the expected density close to the mean.

It was too beautiful to work…

Once again, I should be more careful when emotions mix with mathematics…

Coming back to this thread after one year.

Here is a simple approximate implementation of truncated_normal, truncating at 2*std.

def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)

Hi! Does it implement the same thing as https://www.tensorflow.org/api_docs/python/tf/truncated_normal ?

Here is a simpler way to sample values from truncated normal distribution.

from scipy.stats import truncnorm
import torch


def truncated_normal(size, threshold=1):
    values = truncnorm.rvs(-threshold, threshold, size=size)
    return values

# usage example
x = truncated_normal([10, 20], threshold=1)   # sample a 10x20 array
x = torch.from_numpy(x).cuda()                # requires a CUDA device

Thanks @ruotianluo.
I confirmed that your function is equivalent to scipy truncnorm.

If anyone wants to reproduce this:

import torch
from scipy.stats import truncnorm
import matplotlib.pyplot as plt

def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor



fig, ax = plt.subplots(1, 1)


def test_truncnorm():
    a, b = -2, 2
    size = 1000000
    r = truncnorm.rvs(a, b, size=size)
    ax.hist(r, density=True, histtype='stepfilled', alpha=0.2, bins=50, label='scipy truncnorm')

    tensor = torch.zeros(size)
    truncated_normal_(tensor)  # defined above; there is no utils module in this script
    r = tensor.numpy()

    ax.hist(r, density=True, histtype='stepfilled', alpha=0.2, bins=50, label='truncated_normal_')
    ax.legend(loc='best', frameon=False)
    plt.show()


if __name__ == '__main__':
    test_truncnorm()

The Tensorflow implementation of truncated_normal can be found on GitHub and references https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf (see page 24) there.

Analogous Numpy/PyTorch code would look like

import numpy as np
import torch


def parameterized_truncated_normal(uniform, mu, sigma, a, b):
    normal = torch.distributions.normal.Normal(0, 1)

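    # Normal.cdf validates that its argument is a tensor, so wrap the bounds.
    alpha = torch.as_tensor((a - mu) / sigma, dtype=uniform.dtype)
    beta = torch.as_tensor((b - mu) / sigma, dtype=uniform.dtype)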

    alpha_normal_cdf = normal.cdf(alpha)
    p = alpha_normal_cdf + (normal.cdf(beta) - alpha_normal_cdf) * uniform

    p = p.numpy()
    one = np.array(1, dtype=p.dtype)
    epsilon = np.array(np.finfo(p.dtype).eps, dtype=p.dtype)
    v = np.clip(2 * p - 1, -one + epsilon, one - epsilon)
    x = mu + sigma * np.sqrt(2) * torch.erfinv(torch.from_numpy(v))
    x = torch.clamp(x, a, b)

    return x


def truncated_normal(uniform):
    return parameterized_truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2)


def sample_truncated_normal(shape=()):
    return truncated_normal(torch.from_numpy(np.random.uniform(0, 1, shape)))

However, in the case of mu=0.0, sigma=1.0, a=-2, b=2, I suspect repeated normal sampling is also fine: about 5% of entries need to be resampled each iteration. E.g. since 0.05**5 * 1000000 < 1, one would need roughly 5 iterations for 1M entries.
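The back-of-the-envelope estimate can be checked with a few lines of standard-library Python. This is only a sketch of the expected-value argument (the function name is made up for illustration): it computes the rejection probability P(|Z| > cutoff) exactly via erfc and counts how many draws are needed until the expected number of still-invalid entries drops below one.

```python
import math


def expected_resampling_rounds(n, cutoff=2.0):
    # Probability that a standard normal sample falls outside [-cutoff, cutoff].
    p_reject = math.erfc(cutoff / math.sqrt(2.0))
    rounds = 0
    remaining = float(n)  # expected number of entries still to be (re)drawn
    while remaining >= 1.0:
        remaining *= p_reject
        rounds += 1
    return p_reject, rounds


p_reject, rounds = expected_resampling_rounds(1_000_000)
print(p_reject, rounds)
```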


In case anyone needs a truncated normal distribution class, implementing the torch.distributions.Distribution interface, I put together one here: https://github.com/toshas/torch_truncnorm


Hello @anton, thanks for posting your take on this. The class was very intuitive to use. But I found that when supplied a value outside of [a,b], log_prob does not give -inf. For example, for a distribution with a=0, b=inf, loc=0, scale=1, log_prob(-1) gives -0.7. It might be worth looking into.

I admit this is not the most efficient way of doing it, but it re-draws the values rather than clipping them, as described in the tf specification (“values more than two standard deviations from the mean are discarded and re-drawn”).

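def truncated_normal(t, mean=0.0, std=0.01):
    torch.nn.init.normal_(t, mean=mean, std=std)
    while True:
        # Mask of entries more than two standard deviations from the mean.
        cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
        if not cond.any():
            break
        # Redraw on the same device and dtype as t; keep only the flagged entries.
        redraw = torch.nn.init.normal_(torch.empty_like(t), mean=mean, std=std)
        t = torch.where(cond, redraw, t)
    return t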

I am using it like this: m.weight.data = truncated_normal(m.weight.data) to initialize my weights. It is convenient in my case.


Use torch.nn.init.trunc_normal_.

Description as given in the documentation:

Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
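A minimal usage sketch, assuming a recent PyTorch where torch.nn.init.trunc_normal_ is available. Note that a and b are absolute cutoff values, not multiples of std, so to truncate at two standard deviations you have to scale them yourself. The tensor shape and the std value below are arbitrary choices for illustration.

```python
import torch

std = 0.02
w = torch.empty(256, 128)
# Truncate at two standard deviations around the mean; a and b are absolute bounds.
torch.nn.init.trunc_normal_(w, mean=0.0, std=std, a=-2 * std, b=2 * std)
print(w.min(), w.max(), w.std())
```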
