In case anyone needs a truncated normal distribution class implementing the `torch.distributions.Distribution` interface, I put one together here: https://github.com/toshas/torch_truncnorm
Hello @anton, thanks for posting your take on this. The class was very intuitive to use. However, I found that when given a value outside of [a, b], `log_prob` does not return -inf. For example, for a distribution with a=0, b=inf, loc=0, scale=1, `log_prob(-1)` gives -0.7. It might be worth looking into.
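The -0.7 above is consistent with the truncated density being evaluated without a support mask: log φ(-1) - log 0.5 ≈ -1.419 + 0.693 ≈ -0.726. A minimal sketch of a masked log-density is below. The function name `truncnorm_log_prob` is hypothetical and is not the API of the linked repository; it only uses `torch.distributions.Normal` and `torch.where` to set points outside [a, b] to -inf.

```python
import math
import torch
from torch.distributions import Normal

def truncnorm_log_prob(x, loc, scale, a, b):
    """Hypothetical helper: log-density of N(loc, scale^2) truncated to [a, b].

    Returns -inf for values outside [a, b].
    """
    x = torch.as_tensor(x, dtype=torch.float64)
    base = Normal(torch.as_tensor(loc, dtype=torch.float64),
                  torch.as_tensor(scale, dtype=torch.float64),
                  validate_args=False)
    a_t = torch.as_tensor(a, dtype=torch.float64)
    b_t = torch.as_tensor(b, dtype=torch.float64)
    # Normalizing constant Z = Phi((b - loc) / scale) - Phi((a - loc) / scale)
    log_z = torch.log(base.cdf(b_t) - base.cdf(a_t))
    lp = base.log_prob(x) - log_z
    # Mask everything outside the support [a, b] with -inf
    inside = (x >= a_t) & (x <= b_t)
    return torch.where(inside, lp, torch.full_like(lp, -math.inf))

print(truncnorm_log_prob(torch.tensor([-1.0, 1.0]), 0.0, 1.0, 0.0, math.inf))
```

Masking after computing the unconstrained log-density keeps the function vectorized; the alternative of raising on out-of-support values (as `validate_args=True` would for built-in distributions) is stricter but less convenient inside batched likelihoods.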
I admit this is not the most efficient way of doing it, but it re-draws the values rather than clipping them, as described in the TF specification ("values more than two standard deviations from the mean are discarded and re-drawn"):
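```python
import torch

def truncated_normal(t, mean=0.0, std=0.01):
    # Fill t with samples from N(mean, std^2)
    torch.nn.init.normal_(t, mean=mean, std=std)
    while True:
        # Entries more than two standard deviations from the mean
        cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
        if not torch.any(cond):
            break
        # Redraw the out-of-range entries; empty_like keeps t's dtype and device
        t = torch.where(cond, torch.nn.init.normal_(torch.empty_like(t), mean=mean, std=std), t)
    return t
```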
I use it like this to initialize my weights, which is convenient in my case: `m.weight.data = truncated_normal(m.weight.data)`
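A sketch of an in-place variant that avoids assigning to `.data` and redraws only the offending entries, applied to a whole model with `nn.Module.apply`. The names `truncated_normal_` and `init_weights`, the toy model, and the choice to zero the biases are illustrative assumptions, not part of the post above.

```python
import torch
import torch.nn as nn

def truncated_normal_(t, mean=0.0, std=0.01):
    """Hypothetical in-place variant: redraw entries outside mean +/- 2*std."""
    with torch.no_grad():
        t.normal_(mean, std)
        while True:
            cond = (t < mean - 2 * std) | (t > mean + 2 * std)
            if not cond.any():
                return t
            # Fresh samples only for the out-of-range positions
            t[cond] = torch.empty(int(cond.sum()), dtype=t.dtype,
                                  device=t.device).normal_(mean, std)

def init_weights(m):
    # Illustrative init hook: only touches nn.Linear layers
    if isinstance(m, nn.Linear):
        truncated_normal_(m.weight, std=0.01)
        nn.init.zeros_(m.bias)

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
model.apply(init_weights)
```

Wrapping the work in `torch.no_grad()` lets the function modify `nn.Parameter` leaves in place without autograd complaining.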
Use `torch.nn.init.trunc_normal_`.
Description as given here:
Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution $\mathcal{N}(\text{mean}, \text{std}^2)$ with values outside $[a, b]$ redrawn until they are within the bounds. The method used for generating the random values works best when $a \leq \text{mean} \leq b$.
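A short usage sketch. One thing worth knowing: `a` and `b` are absolute cut-off values (defaults -2.0 and 2.0), not multiples of `std`. So to reproduce the TF "two standard deviations" rule with a small `std`, the bounds have to be scaled explicitly; leaving the defaults with `std=0.01` would truncate essentially nothing. As far as I can tell from the PyTorch source, the function samples via the inverse CDF (uniform, then `erfinv`) and clamps to `[a, b]` rather than literally looping, which is why it behaves best when `mean` lies inside the bounds. The tensor shape and values below are arbitrary.

```python
import torch

mean, std = 0.0, 0.01
w = torch.empty(256, 128)

# a and b are absolute bounds, so scale them by std to mimic TF's +/- 2 std rule
torch.nn.init.trunc_normal_(w, mean=mean, std=std, a=mean - 2 * std, b=mean + 2 * std)

print(w.min().item(), w.max().item())
```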
It seems `torch.nn.init.trunc_normal_` does not appear in the `torch.nn.init` documentation, so I am a little confused: is it a stable version of this method?
As of March 30, 2022, `torch.nn.init.trunc_normal_` is still not in the docs. Is there any plan to document it?
The docstring seems to be properly defined here, and you can also access it in IPython, so I'm wondering if the docs generation is somehow missing it (@sachin_yadav already pointed this out as well).
CC @albanD: do you know where the docs generation might be failing?
It should be added to this file: `pytorch/docs/source/nn.init.rst` in pytorch/pytorch on GitHub (at commit 0765a804911673fb2d9694a76ba0196ea0eddec4).
@quocdat32461997, if you want to send a PR fixing that, you can add me as a reviewer!
Hi, I sent PR #76530 to fix this. Since this is my first PR to PyTorch, any suggestions would be appreciated.
Looks good to me, thank you @baudzhou .
Currently, `torch.nn.init.trunc_normal_` does not appear in the `torch.nn.init` documentation. Is there a timeline for adding it?