I’m currently working on recovering the input of a LayerNorm layer from its output. As far as I know, the mean and standard deviation that LayerNorm uses are fixed during inference. Therefore, I thought I could extract these factors and reconstruct the original input from the LayerNorm output.
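For reference, here is a minimal sketch of what `nn.LayerNorm` computes in its forward pass. It assumes the same shapes as the code below (`(1, 577, 768)`, normalized over the last dimension) and the default affine parameters. The mean and biased variance are computed from the tensor passed in, once per token, and calling `.eval()` does not change that.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(1, 577, 768)
ln = nn.LayerNorm(768, eps=1e-6).eval()

with torch.no_grad():
    # One mean and one biased variance per token, taken over the last
    # dimension of *this* input; eval() does not freeze them.
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
    manual = (x - mean) / torch.sqrt(var + ln.eps) * ln.weight + ln.bias
    reference = ln(x)

print(torch.allclose(manual, reference, atol=1e-5))
```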
I have successfully extracted a per-token weight and bias. They are not necessarily identical to the mean and standard deviation, because LayerNorm also has its own weight and bias parameters, so my weight and bias are a fusion of several factors. Still, they do recreate the original input from the LayerNorm output.
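A minimal sketch of the same per-token two-point solve, vectorized over all tokens. As an assumption of this sketch, it solves from each token’s largest and smallest features instead of features 0 and 10, so the denominator is never close to zero. With the default initialization (LayerNorm weight = 1, bias = 0), the fused factors should come out as that input’s per-token standard deviation and mean.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(1, 577, 768)
ln = nn.LayerNorm(768, eps=1e-6).eval()

with torch.no_grad():
    y = ln(x)

# Pick two well-separated features per token (assumption: max and min of x)
imax = x.argmax(dim=-1, keepdim=True)
imin = x.argmin(dim=-1, keepdim=True)
x_hi, x_lo = x.gather(-1, imax), x.gather(-1, imin)
y_hi, y_lo = y.gather(-1, imax), y.gather(-1, imin)

# Solve x = y * w + b per token; w and b have shape (1, 577, 1)
w = (x_hi - x_lo) / (y_hi - y_lo)
b = x_hi - w * y_hi

# Per-token statistics of this particular input
mean = x.mean(dim=-1, keepdim=True)
std = torch.sqrt(((x - mean) ** 2).mean(dim=-1, keepdim=True) + ln.eps)

print(torch.allclose(w, std, rtol=1e-4, atol=1e-5))
print(torch.allclose(b, mean, atol=1e-4))
print(torch.allclose(y * w + b, x, atol=1e-4))
```

Because `w` and `b` are exactly the statistics of the input they were solved from, they only invert outputs of that same input.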
However, when I applied the same extracted weight and bias to the LayerNorm output of a different input tensor, expecting the reconstruction to work as it did for the first input, the result was far off (the second printed value below). Re-solving the weight and bias from the second input itself makes it work again (the third printed value). I assume LayerNorm computed a new mean and standard deviation for the second input, which would explain the difference. But I’m puzzled as to why LayerNorm computed the mean and standard deviation for the second input at all, since I expected them to remain fixed during inference.
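One way to check whether LayerNorm has any stored statistics to freeze is to list its parameters and buffers. The sketch below compares it with `nn.BatchNorm1d`, which does keep running statistics that `.eval()` switches it to use. The two random inputs are only illustrative.

```python
import torch
from torch import nn

ln = nn.LayerNorm(768, eps=1e-6).eval()
bn = nn.BatchNorm1d(768).eval()

# LayerNorm only holds its affine parameters; there are no running statistics.
print([name for name, _ in ln.named_parameters()])  # ['weight', 'bias']
print([name for name, _ in ln.named_buffers()])     # []

# BatchNorm stores running statistics as buffers.
print([name for name, _ in bn.named_buffers()])     # ['running_mean', 'running_var', 'num_batches_tracked']

torch.manual_seed(0)
x1 = torch.randn(1, 577, 768)
x2 = torch.randn(1, 577, 768)
# Per-token means differ between inputs, and LayerNorm uses each input's own.
print(torch.allclose(x1.mean(dim=-1), x2.mean(dim=-1)))  # False
```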
```python
import torch
from torch import nn

input_data = torch.randn(1, 577, 768)
input_remade = torch.empty_like(input_data)  # fully overwritten below

class Layer(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm1 = nn.LayerNorm(768, eps=1e-6)
        # self.norm1 = nn.LayerNorm(768, eps=0)

    def forward(self, x):
        x = self.norm1(x)
        return x

layer = Layer().eval()  # class renamed so the instance does not shadow it

with torch.inference_mode():
    out = layer(input_data)

# Per token i, solve input = out * w[i] + b[i] from features 0 and 10
n_tokens = out.shape[1]
w = torch.zeros(n_tokens)
b = torch.zeros(n_tokens)
for i in range(n_tokens):
    w[i] = (input_data[0, i, 0] - input_data[0, i, 10]) / (out[0, i, 0] - out[0, i, 10])
    b[i] = (input_data[0, i, 0] * out[0, i, 10] - input_data[0, i, 10] * out[0, i, 0]) / (out[0, i, 10] - out[0, i, 0])

# Reconstruct the first input with its own w and b
for i1 in range(n_tokens):
    input_remade[0, i1, :] = out[0, i1, :] * w[i1] + b[i1]
print(torch.sum(input_remade - input_data))

# Second input, reusing w and b solved from the first input
input_data2 = torch.randn(1, 577, 768)
input_remade2 = torch.empty_like(input_data2)
with torch.inference_mode():
    out2 = layer(input_data2)
for i1 in range(n_tokens):
    input_remade2[0, i1, :] = out2[0, i1, :] * w[i1] + b[i1]
print(torch.sum(input_remade2 - input_data2))

# Second input, with w1 and b1 re-solved from the second input itself
w1 = torch.zeros(n_tokens)
b1 = torch.zeros(n_tokens)
for i in range(n_tokens):
    w1[i] = (input_data2[0, i, 0] - input_data2[0, i, 10]) / (out2[0, i, 0] - out2[0, i, 10])
    b1[i] = (input_data2[0, i, 0] * out2[0, i, 10] - input_data2[0, i, 10] * out2[0, i, 0]) / (out2[0, i, 10] - out2[0, i, 0])
for i1 in range(n_tokens):
    input_remade2[0, i1, :] = out2[0, i1, :] * w1[i1] + b1[i1]
print(torch.sum(input_remade2 - input_data2))
```
Output:

```text
tensor(-0.0061)
tensor(1280.9966)
tensor(0.0014)
```
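If the goal is to invert LayerNorm for arbitrary inputs, the per-token mean and standard deviation have to be captured on every call, since each token of the normalized output has zero mean and unit variance before the affine step and that information is gone. A minimal sketch, assuming a forward hook is an acceptable place to record them; `record_stats`, `invert_layernorm` and `stats` are hypothetical names, and the random affine parameters only stand in for trained ones.

```python
import torch
from torch import nn

ln = nn.LayerNorm(768, eps=1e-6).eval()
with torch.no_grad():
    # Stand-in for trained affine parameters (assumption: no zero weights)
    ln.weight.uniform_(0.5, 1.5)
    ln.bias.normal_()

stats = []  # hypothetical storage: one (mean, std) pair per forward call

def record_stats(module, args, output):
    # args[0] is the tensor passed to LayerNorm in this call
    x = args[0]
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
    stats.append((mean, torch.sqrt(var + module.eps)))

def invert_layernorm(module, y, mean, std):
    # Undo the elementwise affine step, then the normalization
    y_hat = (y - module.bias) / module.weight
    return y_hat * std + mean

handle = ln.register_forward_hook(record_stats)

torch.manual_seed(0)
inputs = [torch.randn(1, 577, 768) for _ in range(2)]
with torch.no_grad():
    outputs = [ln(x) for x in inputs]
    recovered = [invert_layernorm(ln, y, m, s) for y, (m, s) in zip(outputs, stats)]
handle.remove()

for x, x_rec in zip(inputs, recovered):
    print((x_rec - x).abs().max())
```

With trained, non-uniform LayerNorm weights, a single scalar weight and bias per token can no longer express the inverse, which is why this sketch undoes the elementwise affine step separately.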