Implementing Truncated Normal Initializer

Thanks @ruotianluo.
I confirmed that your function is equivalent to scipy's truncnorm.

If anyone wants to reproduce it:

import torch
from scipy.stats import truncnorm
import matplotlib.pyplot as plt

def truncated_normal_(tensor, mean=0, std=1):
    # Draw 4 standard-normal candidates per element and keep the first one
    # that falls inside [-2, 2] (rejection sampling with a fixed budget).
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    # Scale and shift to the requested std and mean.
    tensor.data.mul_(std).add_(mean)
    return tensor
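
A minimal sketch of using this as a weight initializer (the nn.Linear layer and std=0.02 below are illustrative assumptions of mine, not part of the original code):

import torch.nn as nn

# Hypothetical usage: initialize a layer's weights with the helper above.
layer = nn.Linear(128, 64)
truncated_normal_(layer.weight, mean=0.0, std=0.02)  # samples land in [mean - 2*std, mean + 2*std]
nn.init.zeros_(layer.bias)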



fig, ax = plt.subplots(1, 1)


def test_truncnorm():
    a, b = -2, 2
    size = 1000000
    r = truncnorm.rvs(a, b, size=size)
    ax.hist(r, density=True, histtype='stepfilled', alpha=0.2, bins=50, label='scipy truncnorm')

    tensor = torch.zeros(size)
    truncated_normal_(tensor)
    r = tensor.numpy()

    ax.hist(r, density=True, histtype='stepfilled', alpha=0.2, bins=50, label='truncated_normal_')
    ax.legend(loc='best', frameon=False)
    plt.show()


if __name__ == '__main__':
    test_truncnorm()