Implementation of layernorm, precision is low

I asked about the implementation of layernorm in this post

I implemented it in both numpy and pytorch. It seems weird to me that the same computation gives very different precision in the two libraries. Here’s the torch.nn layernorm:

import numpy as np
import torch

tolerance_1 = 1e-6
tolerance_2 = 1e-3


y = torch.randn(50,768)
lnorm = torch.nn.LayerNorm(y.shape[-1])

#torch.nn layernorm output
layernorm_output = lnorm(y)

Here’s my implementation in pytorch:

#my implementation of LayerNorm in pytorch
mean = torch.mean(y,axis=1)
var = torch.var(y,axis=1)
div = torch.sqrt(var+lnorm.eps)
stnd = (y - mean.unsqueeze(-1))/div.unsqueeze(-1)
my_output = stnd*lnorm.weight +lnorm.bias

Comparing it to torch.nn.LayerNorm:

print("Implementation in pytorch, Precision {}:".format(tolerance_2))
print(torch.isclose(layernorm_output,my_output,tolerance_2))
print("Implementation in pytorch, Precision {}:".format(tolerance_1))
print(torch.isclose(layernorm_output,my_output,tolerance_1))

Here’s the result:

Implementation in pytorch, Precision 0.001:
tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        ...,
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])
Implementation in pytorch, Precision 1e-06:
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

As you can see, the outputs only match at a tolerance of 1e-3.

However, here’s my implementation in numpy:

#my implementation of LayerNorm in Numpy
y_numpy = y.detach().numpy()
mean = np.mean(y_numpy,axis=1)
var = np.var(y_numpy,axis=1)
div = np.sqrt(var+lnorm.eps)
stnd = (y_numpy - mean.reshape(-1,1))/div.reshape(-1,1)
my_output_np = stnd*lnorm.weight.detach().numpy() +lnorm.bias.detach().numpy()

print("Implementation in numpy, Precision {}:".format(tolerance_1))
print(np.isclose(layernorm_output.detach().numpy(),my_output_np,tolerance_1))

Here the precision seems much higher (1e-6):

Implementation in numpy, Precision 1e-06:
[[ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 ...
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]
 [ True  True  True ...  True  True  True]]

Is this expected behavior?

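  1. torch.isclose takes rtol as its third positional argument, so the tolerance above is applied as a relative tolerance. Pass it as atol (by keyword) instead.
  2. torch.var applies Bessel’s correction by default (it divides by N-1), unlike np.var and LayerNorm, which divide by N.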
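
To make both points concrete, here is a minimal sketch that reruns the comparison with the variance computed both ways and the tolerance passed by keyword. The fixed seed, the helper name manual_layernorm and the atol value of 1e-5 are assumptions made for this illustration (1e-5 is chosen as a float32-friendly absolute threshold), not part of the original post. With N = 768 the two standard deviations differ by a factor of roughly sqrt(768/767), about 1.00065, which would be consistent with the check passing at 1e-3 and failing at 1e-6.

```python
import torch

torch.manual_seed(0)  # assumed fixed seed so the run is repeatable

y = torch.randn(50, 768)
lnorm = torch.nn.LayerNorm(y.shape[-1])

with torch.no_grad():
    reference = lnorm(y)
    mean = y.mean(dim=-1, keepdim=True)

    def manual_layernorm(var):
        # Hypothetical helper: normalize with the given per-row variance,
        # then apply LayerNorm's learned scale and shift.
        return (y - mean) / torch.sqrt(var + lnorm.eps) * lnorm.weight + lnorm.bias

    # Default torch.var: Bessel's correction, divides by N-1.
    out_bessel = manual_layernorm(y.var(dim=-1, keepdim=True))
    # unbiased=False divides by N, matching np.var's default and LayerNorm.
    out_fixed = manual_layernorm(y.var(dim=-1, unbiased=False, keepdim=True))

# Largest absolute deviation from torch.nn.LayerNorm for each variant.
diff_bessel = (reference - out_bessel).abs().max().item()
diff_fixed = (reference - out_fixed).abs().max().item()
print("max abs diff with Bessel's correction:", diff_bessel)
print("max abs diff with unbiased=False:", diff_fixed)

# Pass tolerances by keyword so rtol and atol cannot be mixed up.
print(torch.allclose(reference, out_fixed, rtol=0.0, atol=1e-5))
```

In recent PyTorch releases the same effect should also be available through the correction=0 keyword of torch.var; unbiased=False is used here because it matches the wording of the reply.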