Thanks @ruotianluo.
I confirmed that your function produces the same distribution as SciPy's `scipy.stats.truncnorm` (truncated at ±2 standard deviations).
If anyone wants to reproduce this:
import torch
from scipy.stats import truncnorm
import matplotlib.pyplot as plt
def truncated_normal_(tensor, mean=0, std=1):
    """Fill ``tensor`` in place with truncated-normal samples and return it.

    Samples are drawn from a standard normal restricted to the open
    interval (-2, 2), then scaled by ``std`` and shifted by ``mean``.
    For ``std > 0`` every value therefore lies in
    ``(mean - 2 * std, mean + 2 * std)``.

    Args:
        tensor: Floating-point tensor to overwrite in place.
        mean: Value added after scaling.
        std: Scale factor applied to the standard truncated samples.

    Returns:
        The same ``tensor`` object, now filled with samples.
    """
    with torch.no_grad():
        # Draw 4 standard-normal candidates per element. Each one lands in
        # (-2, 2) with probability ~0.954, so nearly every element has at
        # least one usable candidate without any Python-level loop.
        candidates = tensor.new_empty(tensor.shape + (4,)).normal_()
        valid = (candidates < 2) & (candidates > -2)
        # Pick a valid candidate per element (index 0 if none is valid).
        ind = valid.max(-1, keepdim=True)[1]
        samples = candidates.gather(-1, ind).squeeze(-1)

        # With probability ~4e-6 per element all four candidates are out of
        # range, and the selection above then keeps an out-of-range value.
        # Redraw just those entries until they fall inside (-2, 2). This is
        # plain rejection sampling, so the result is exactly truncated.
        bad = (samples >= 2) | (samples <= -2)
        while bad.any():
            samples[bad] = samples.new_empty(int(bad.sum())).normal_()
            bad = (samples >= 2) | (samples <= -2)

        tensor.copy_(samples.mul_(std).add_(mean))
    return tensor
def test_truncnorm():
    """Overlay scipy's truncnorm and ``truncated_normal_`` histograms.

    Draws 1,000,000 samples from ``scipy.stats.truncnorm`` with bounds
    (-2, 2) and the same number from ``truncated_normal_`` (default mean 0,
    std 1). Both are plotted as density-normalised 50-bin histograms on one
    axis so the two distributions can be compared visually, then the figure
    is shown.
    """
    a, b = -2, 2
    size = 1000000
    # Create the figure inside the function so that importing this module
    # does not open a plotting window as a side effect.
    _, ax = plt.subplots(1, 1)

    reference = truncnorm.rvs(a, b, size=size)
    ax.hist(reference, density=True, histtype='stepfilled', alpha=0.2,
            bins=50, label='scipy truncnorm')

    tensor = torch.zeros(size)
    # Use the function defined in this script; there is no ``utils`` module
    # in scope here.
    truncated_normal_(tensor)
    ours = tensor.numpy()
    ax.hist(ours, density=True, histtype='stepfilled', alpha=0.2,
            bins=50, label='truncated_normal_')

    # The histograms carry labels, so the legend actually has entries.
    ax.legend(loc='best', frameon=False)
    plt.show()


if __name__ == '__main__':
    test_truncnorm()