So it looks like you take all the elements over 2 and put them back close to the mean, like in the snake game, and that’s why you still have too many samples close to the mean. And it seems that the worst case comes with variance = 1. It would be worth writing down the mathematics, to see how far we are from a Gaussian with both methods.
Well, I think one of the issues is that your sketch seems different from the usual definition of a truncated Gaussian.
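For reference (my addition, using the standard notation where \varphi and \Phi are the standard normal pdf and cdf), the usual truncated Gaussian on [a, b] renormalizes the mass inside the interval rather than folding the tails back in:

f(x \mid \mu, \sigma, a, b) = \frac{\tfrac{1}{\sigma}\,\varphi\!\left(\tfrac{x-\mu}{\sigma}\right)}{\Phi\!\left(\tfrac{b-\mu}{\sigma}\right) - \Phi\!\left(\tfrac{a-\mu}{\sigma}\right)}, \qquad a \le x \le b,

and f(x) = 0 outside [a, b].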
So what the .fmod in the sampling does in terms of pdf is to sum the pdf over all x + sign(x)*cutoff*i for i = 0, 1, 2, …. If the shape of the pdf in [cutoff*i, cutoff*(i+1)] were proportional to that in [0, cutoff], this would be exact. As it is, the pdf falls off much more quickly than proportionally (a 3 std event, conditional on the magnitude being at least 2 std, is much less likely than a 1 std event). This is why the summation puts too much mass in the center. With a smaller cutoff, the pdf in [-cutoff, cutoff] becomes flatter (closer to uniform) and thus the error is more pronounced. This would further increase if you cut off at less than 1 std, but at that point you might just use a uniform distribution anyway. Even for a cutoff at 1 you could fold at 2/3 (or so) and then multiply the sample by 3/2 to approximate the truncated normal better (i.e. torch.fmod(torch.randn(size), 2/3)*3/2). A small sketch of the two variants is below.
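A minimal sketch (my own illustration, not code from the thread) of the two folding variants just described; neither is an exact truncated normal:

import torch

size = (1_000_000,)

# Plain fold at the cutoff of 2 std: the tails are wrapped back towards 0,
# which piles extra mass near the center.
folded_at_2 = torch.fmod(torch.randn(size), 2.0)

# Suggested variant for a cutoff at 1 std: fold at 2/3 and stretch by 3/2,
# so the support is still (-1, 1) but the shape is less distorted.
folded_rescaled = torch.fmod(torch.randn(size), 2.0 / 3.0) * 1.5

print(folded_at_2.abs().max(), folded_rescaled.abs().max())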
In the plain Box-Muller method, you are essentially using two transforms: first, that for U_1 uniform on [0,1], R = sqrt(-2 ln(U_1)) has the distribution of the radius of a 2d standard normal, χ(2). Thus X, Y = cos(2π U_2) * R, sin(2π U_2) * R will be 2d standard normal. Then you are using that the marginals of the 2d standard normal are 1d standard normals.
If you now restrict R to R <= c and look at the pdf of X close to c, only very little of the total mass along the vertical (Y-coordinate) line at that X lies in the disc R <= c. This is why the “restricted BM” distribution has its density going to 0 as X approaches c.
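A minimal sketch of that argument (my own illustration, not code from the thread): plain Box-Muller, plus a variant that conditions on R <= c; the latter is not a truncated normal because its density vanishes as |X| approaches c.

import math
import torch

def box_muller(n, c=None):
    # Plain Box-Muller: R = sqrt(-2 ln U1) is chi(2)-distributed, and
    # (R cos(2 pi U2), R sin(2 pi U2)) is a 2d standard normal, so the
    # marginal X alone is a 1d standard normal.
    u1 = torch.rand(n)
    if c is not None:
        # "Restricted BM": condition on R <= c by drawing U1 from
        # [exp(-c^2/2), 1], since R <= c  <=>  U1 >= exp(-c^2/2).
        # The marginal of X is then NOT a truncated normal: its density
        # goes to 0 as |X| approaches c, for the reason described above.
        lo = math.exp(-c * c / 2.0)
        u1 = lo + (1.0 - lo) * u1
    u2 = torch.rand(n)
    r = torch.sqrt(-2.0 * torch.log(u1))
    return r * torch.cos(2 * math.pi * u2)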
But then I would be surprised if the initialisation heuristic broke down with a simple approximation.
Ok thanks, now I understand better. So, close to x = 0, all samples should be accepted, since cos(U_2) is small, but with my approach I still multiply by a restricted R and make it even smaller. That explains the plot from @ruotianluo where I exceed the expected density close to the mean.
It was too beautiful to work…
Once again, I should be more careful when emotions mix with mathematics…
import numpy as np
import torch

def parameterized_truncated_normal(uniform, mu, sigma, a, b):
    # Inverse-CDF sampling: map the uniform draws into [Phi(alpha), Phi(beta)],
    # then invert the normal CDF via erfinv.
    normal = torch.distributions.normal.Normal(0, 1)

    alpha = (a - mu) / sigma
    beta = (b - mu) / sigma

    alpha_normal_cdf = normal.cdf(alpha)
    p = alpha_normal_cdf + (normal.cdf(beta) - alpha_normal_cdf) * uniform

    p = p.numpy()
    one = np.array(1, dtype=p.dtype)
    epsilon = np.array(np.finfo(p.dtype).eps, dtype=p.dtype)
    # Keep the erfinv argument strictly inside (-1, 1) to avoid infinities.
    v = np.clip(2 * p - 1, -one + epsilon, one - epsilon)
    x = mu + sigma * np.sqrt(2) * torch.erfinv(torch.from_numpy(v))
    x = torch.clamp(x, a, b)
    return x

def truncated_normal(uniform):
    return parameterized_truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2)

def sample_truncated_normal(shape=()):
    return truncated_normal(torch.from_numpy(np.random.uniform(0, 1, shape)))
However, in the case of mu=0.0, sigma=1.0, a=-2, b=2, I suspect repeated normal sampling is also fine: only about 5% of entries need to be resampled each iteration. E.g. since 0.05**5 * 1000000 < 1, one would need roughly 5 iterations for 1M entries.
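A rough sketch of that resampling idea (my illustration, not the poster's code): redraw only the out-of-range entries until everything lies within 2 std of the mean.

import torch

def resample_truncated_normal(shape, mean=0.0, std=1.0):
    t = torch.randn(shape) * std + mean
    while True:
        # About 5% of entries fall outside [mean - 2*std, mean + 2*std].
        bad = (t < mean - 2 * std) | (t > mean + 2 * std)
        if not bad.any():
            return t
        t[bad] = torch.randn(int(bad.sum())) * std + mean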
In case anyone needs a truncated normal distribution class, implementing the torch.distributions.Distribution interface, I put together one here: https://github.com/toshas/torch_truncnorm
Hello @anton, thanks for posting your take on this. The class was very intuitive to use. But I found that when supplied a value outside of [a, b], log_prob does not give -inf. For example, for a distribution with a=0, b=inf, loc=0, scale=1, log_prob(-1) gives -0.7. Might be worth looking into.
def truncated_normal(t, mean=0.0, std=0.01):
    # Fill t from N(mean, std^2), then redraw any entries outside
    # [mean - 2*std, mean + 2*std] until all are within the bounds.
    torch.nn.init.normal_(t, mean=mean, std=std)
    while True:
        cond = torch.logical_or(t < mean - 2*std, t > mean + 2*std)
        if not torch.sum(cond):
            break
        t = torch.where(cond, torch.nn.init.normal_(torch.ones(t.shape), mean=mean, std=std), t)
    return t
I am using it like this: m.weight.data = truncated_normal(m.weight.data) to initialize my weights. It is convenient in my case.
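As a hedged illustration of that usage pattern (the model and layer choices below are placeholders, not from the post), the same helper can be applied model-wide with nn.Module.apply:

import torch.nn as nn

def init_weights(m):
    # Apply the truncated_normal helper above to every Linear/Conv2d layer.
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        m.weight.data = truncated_normal(m.weight.data, mean=0.0, std=0.01)
        if m.bias is not None:
            m.bias.data.zero_()

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
model.apply(init_weights)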
Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
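For anyone who just wants to call it, a minimal usage sketch (assuming the signature trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0); the layer and the chosen std are placeholders):

import torch.nn as nn

# Fill a weight tensor in place with samples from N(0, 0.02^2),
# redrawing values that fall outside [-0.04, 0.04], i.e. 2 std.
layer = nn.Linear(128, 64)
nn.init.trunc_normal_(layer.weight, mean=0.0, std=0.02, a=-0.04, b=0.04)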
It seems torch.nn.init.trunc_normal_ does not appear in the documentation of torch.nn.init, so I am a little confused whether it is the stable version of this method?
The doc string seems to be properly defined here and you can access it also in IPython so I’m wondering if the docs generation is somehow missing it (@sachin_yadav also pointed this out already).
CC @albanD do you know where the docs generation might fail?
Currently, torch.nn.init.trunc_normal_ does not appear in the documentation of torch.nn.init. So, is there a timeline for adding it to the documentation?