Implementing Truncated Normal Initializer

Wow, and your approximation takes one line!

So it looks like you take all the elements beyond 2 and fold them back close to the mean, a bit like in a snake game, and that is why you still have too many samples close to the mean. And it seems that the worst case comes with variance = 1. It would be worth writing down the mathematics to see how far we are from a Gaussian with both methods.

Hi Alexis,

Well, I think part of it is that your sketch differs from the usual definition of a truncated Gaussian.
What the .fmod in the sampling does, in terms of the pdf, is to sum the pdf over all x + sign(x)·cutoff·i for i = 0, 1, 2, …. If the shape of the pdf on [cutoff·i, cutoff·(i+1)] were proportional to its shape on [0, cutoff], this would be exact. As it is, the pdf falls off much faster than proportionally (a 3 std event, conditional on the magnitude being at least 2 std, is much less likely than a 1 std event), which is why the summation puts too much mass in the centre. With a smaller cutoff, the pdf on [-cutoff, cutoff] becomes flatter (closer to uniform) and thus the error is more pronounced; it would increase further if you cut off at less than 1 std, but at that point you might as well use a uniform distribution. Even at 1 std you could use a cutoff of 2/3 (or so) and then multiply the sample by 3/2 to better approximate the truncated normal (i.e. torch.fmod(torch.randn(size), 2/3) * 3/2).
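For concreteness, a quick sketch of the folding trick and the rescaled variant mentioned above (the rejection-based reference tensor is my own addition, purely for comparison):

import torch

size = (1_000_000,)

# The one-liner under discussion: fmod folds every sample back into (-2, 2),
# so the folded tail mass piles up near the centre of the interval.
folded_2std = torch.fmod(torch.randn(size), 2.0)

# Exact truncation at 2 std via simple rejection, as a reference.
ref_2std = torch.randn(2 * size[0])
ref_2std = ref_2std[ref_2std.abs() < 2.0][: size[0]]

# The suggested variant for a cutoff of 1 std: fold at 2/3, then stretch by 3/2
# so the support is (-1, 1) again.
folded_1std = torch.fmod(torch.randn(size), 2.0 / 3.0) * 1.5

print(folded_2std.std().item(), ref_2std.std().item(), folded_1std.std().item())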

In the plain Box-Muller method, you are essentially using two facts: first, that for U_1 uniform on [0, 1], R = sqrt(-2 ln(U_1)) has the distribution of the radius of a 2d standard normal (χ with 2 degrees of freedom), so that X = R cos(2π U_2), Y = R sin(2π U_2) is a 2d standard normal. Second, that the marginals of the 2d standard normal are 1d standard normals.

If you now restrict R to R ≤ c and look at the pdf of X close to c, only very little of the mass along the vertical (Y-coordinate) direction still lies inside the disc R ≤ c. This is why the “restricted Box-Muller” distribution has its density going to 0 as X approaches c.
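A minimal sketch of these two transforms, and of what restricting R does (the function name and the cap of 2.0 are just for illustration):

import numpy as np

rng = np.random.default_rng(0)

def box_muller(n, r_cap=None):
    """Box-Muller samples of X; if r_cap is set, pairs with radius > r_cap are dropped."""
    u1, u2 = rng.uniform(size=n), rng.uniform(size=n)
    r = np.sqrt(-2.0 * np.log(u1))        # radius of a 2d standard normal
    if r_cap is not None:
        keep = r <= r_cap                 # restrict R, as described above
        r, u2 = r[keep], u2[keep]
    return r * np.cos(2.0 * np.pi * u2)   # marginal X of the 2d normal

plain = box_muller(1_000_000)                  # standard normal
restricted = box_muller(1_000_000, r_cap=2.0)  # density of X drops to 0 as |X| -> 2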

But then, I would be surprised if the initialisation heuristic broke down with a simple approximation.

Best regards

Thomas


OK thanks, now I understand better. So close to x = 0 all samples should be accepted, since cos(2π U_2) is small, but with my approach I still multiply by a restricted R, making the values even smaller. That explains the plot from @ruotianluo where I exceed the expected density close to the mean.

It was too beautiful to work…

Once again, I should be more careful when emotions mix with mathematics…

Coming back to this thread after one year.

One simple approximate implementation of truncated_normal, truncated at 2*std:

def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    # draw 4 candidates per element and keep the first one that falls inside (-2, 2)
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    # index of the first valid candidate (falls back to index 0 if all four are invalid)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    # rescale from the standard normal to the requested mean/std
    tensor.data.mul_(std).add_(mean)
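
If it helps, a usage sketch (the Linear layer is just an example):

import torch.nn as nn

layer = nn.Linear(128, 64)
truncated_normal_(layer.weight, mean=0.0, std=0.02)  # fills the weight in place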

Hi! Does it implement the same thing as https://www.tensorflow.org/api_docs/python/tf/truncated_normal ?

Here is a simpler way to sample values from a truncated normal distribution.

from scipy.stats import truncnorm
import torch


def truncated_normal(size, threshold=1):
    values = truncnorm.rvs(-threshold, threshold, size=size)
    return values

# usage example
x = truncated_normal([10, 20], threshold=1)   # sample a 10x20 array
x = torch.from_numpy(x).cuda()
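
If you also need a nonzero mean or a different std, truncnorm's loc/scale parameters can be used (note that the a, b bounds are expressed in units of scale around loc):

# truncated at mean ± 2*std
values = truncnorm.rvs(-2, 2, loc=0.0, scale=0.02, size=(10, 20))
x = torch.from_numpy(values).float()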

Thanks @ruotianluo.
I confirmed that your function is equivalent to scipy truncnorm.

In case anyone wants to reproduce it:

import torch
from scipy.stats import truncnorm
import matplotlib.pyplot as plt

def truncated_normal_(tensor, mean=0, std=1):
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
    return tensor



fig, ax = plt.subplots(1, 1)


def test_truncnorm():
    a, b = -2, 2
    size = 1000000
    r = truncnorm.rvs(a, b, size=size)
    ax.hist(r, density=True, histtype='stepfilled', alpha=0.2, bins=50, label='scipy truncnorm')

    tensor = torch.zeros(size)
    truncated_normal_(tensor)
    r = tensor.numpy()

    ax.hist(r, density=True, histtype='stepfilled', alpha=0.2, bins=50, label='truncated_normal_')
    ax.legend(loc='best', frameon=False)
    plt.show()


if __name__ == '__main__':
    test_truncnorm()

The TensorFlow implementation of truncated_normal can be found on GitHub; it references https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf (see page 24).
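
In short, that reference samples by inverse-transform: with α = (a − μ)/σ, β = (b − μ)/σ and u uniform on [0, 1],

p = Φ(α) + u · (Φ(β) − Φ(α)),    x = μ + σ · Φ⁻¹(p) = μ + σ · √2 · erfinv(2p − 1),

clamped to [a, b], which is what the code below computes.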

Analogous Numpy/PyTorch code would look like

import numpy as np
import torch


def parameterized_truncated_normal(uniform, mu, sigma, a, b):
    normal = torch.distributions.normal.Normal(0, 1)

    alpha = (a - mu) / sigma
    beta = (b - mu) / sigma

    alpha_normal_cdf = normal.cdf(alpha)
    p = alpha_normal_cdf + (normal.cdf(beta) - alpha_normal_cdf) * uniform

    p = p.numpy()
    one = np.array(1, dtype=p.dtype)
    epsilon = np.array(np.finfo(p.dtype).eps, dtype=p.dtype)
    v = np.clip(2 * p - 1, -one + epsilon, one - epsilon)
    x = mu + sigma * np.sqrt(2) * torch.erfinv(torch.from_numpy(v))
    x = torch.clamp(x, a, b)

    return x


def truncated_normal(uniform):
    return parameterized_truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2)


def sample_truncated_normal(shape=()):
    return truncated_normal(torch.from_numpy(np.random.uniform(0, 1, shape)))

However, in the case of mu=0.0, sigma=1.0, a=-2, b=2, I suspect repeated normal sampling is also fine: only about 5% of entries need to be resampled in each iteration. E.g. since 0.05**5 * 1000000 < 1, one would need roughly 5 iterations for 1M entries.
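A quick numerical check of that estimate (2 · (1 − Φ(2)) ≈ 4.6% of draws land outside ±2σ):

from scipy.stats import norm

p_out = 2 * norm.sf(2.0)   # P(|Z| > 2) ≈ 0.0455
n = 1_000_000
for k in range(1, 6):
    print(k, n * p_out ** k)   # expected number of still-invalid entries after k redraws
# by k = 5 the expected count is below 1, i.e. roughly 5 iterations for 1M entries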


In case anyone needs a truncated normal distribution class, implementing the torch.distributions.Distribution interface, I put together one here: https://github.com/toshas/torch_truncnorm


Hello @anton, thanks for posting your take on this. The class was very intuitive to use. But I found that when supplied a value outside of [a, b], log_prob does not give -inf. For example, for a distribution with a=0, b=inf, loc=0, scale=1, log_prob(-1) gives -0.7. Might be worth looking into.

I admit this is not the most efficient way of doing it, but it re-draws the values rather than clipping them, as described in the TF specification (“values more than two standard deviations from the mean are discarded and re-drawn”):

def truncated_normal(t, mean=0.0, std=0.01):
    torch.nn.init.normal_(t, mean=mean, std=std)
    while True:
        cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
        if not torch.sum(cond):
            break
        # re-draw only the out-of-range entries from a fresh normal sample
        t = torch.where(cond, torch.nn.init.normal_(torch.ones(t.shape), mean=mean, std=std), t)
    return t

I am using it like this: m.weight.data = truncated_normal(m.weight.data) to initialize my weights. It is convenient in my case.


Use torch.nn.init.trunc_normal_.

Description as given here:

Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution N(mean, std²) with values outside [a, b] redrawn until they are within the bounds. The method used for generating the random values works best when a ≤ mean ≤ b.
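
A minimal usage sketch (the layer and the values of std, a, b here are just examples):

import torch.nn as nn

layer = nn.Linear(768, 768)
# in-place; the signature is trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0)
nn.init.trunc_normal_(layer.weight, mean=0.0, std=0.02, a=-0.04, b=0.04)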


It seems torch.nn.init.trunc_normal_ does not appear in the torch.nn.init documentation, so I am a little confused about whether it is a stable version of this method.


As of March 30, 2022, torch.nn.init.trunc_normal_ is still not in the docs. Is there any plan to add it?

The docstring seems to be properly defined here, and you can also access it in IPython, so I'm wondering if the docs generation is somehow missing it (@sachin_yadav also pointed this out already).

CC @albanD do you know where the docs generation might fail?

It should be added to this file: pytorch/docs/source/nn.init.rst at 0765a804911673fb2d9694a76ba0196ea0eddec4 · pytorch/pytorch · GitHub

@quocdat32461997 if you want to send a PR fixing that, you can add me as a reviewer!


Hi, I sent PR #76530 to fix this. Since this is my first PR to PyTorch, any suggestions will be appreciated.


Looks good to me, thank you @baudzhou .

Currently, torch.nn.init.trunc_normal_ does not appear in the torch.nn.init documentation. Is there a timeline for updating the docs?
