LayerNorm must first compute the mean and variance of its input before normalizing it. So why does the following code show that LayerNorm is faster than `torch.var_mean`, which computes only those two statistics?
from functools import partial
import timeit
import torch
def benchmark(x, ln, number=100):
    """Time ``torch.var_mean`` against a ``LayerNorm`` forward pass on ``x``.

    Both operations reduce over the last dimension of ``x``. ``var_mean``
    is called with ``dim=(-1,)``. This assumes ``ln`` normalizes over the
    last dimension only, e.g. ``LayerNorm(x.shape[-1])``; confirm this if
    you pass a module with a multi-dim ``normalized_shape``.

    Args:
        x: Input tensor. Its last dimension must match ``ln``'s
            ``normalized_shape``.
        ln: The LayerNorm module to time.
        number: How many calls of each operation to time.

    Returns:
        ``(var_mean_time, ln_time)``: total seconds for ``number`` calls of
        each operation, as reported by ``timeit.timeit``.
    """
    # dim=-1 equals the original dim=(2,) for 3-D input and also works for
    # other ranks. NOTE(review): torch.var_mean defaults to the unbiased
    # (Bessel-corrected) variance, while LayerNorm uses the biased one. That
    # is only a scale factor, so it is unlikely to explain the timing gap.
    var_mean_fn = partial(torch.var_mean, x, dim=(-1,))
    ln_fn = partial(ln, x)
    # LayerNorm's affine weight/bias require grad by default. Without no_grad,
    # autograd would record a graph for every ln call.
    with torch.no_grad():
        # Warm up once so one-time first-call costs (e.g. allocator setup)
        # are not charged to whichever op happens to be timed first.
        var_mean_fn()
        ln_fn()
        var_mean_time = timeit.timeit(var_mean_fn, number=number)
        ln_time = timeit.timeit(ln_fn, number=number)
    return var_mean_time, ln_time


if __name__ == "__main__":
    # Guarded so that importing this module does not allocate ~50 MB and run
    # a multi-second benchmark.
    x = torch.randn((257, 252, 192), dtype=torch.float32)
    ln = torch.nn.LayerNorm(192)
    ln.eval()
    var_mean_time, ln_time = benchmark(x, ln)
    # NOTE(review): presumably layer_norm dispatches to a dedicated fused
    # kernel while var_mean takes the generic reduction path. Verify with
    # torch.profiler before drawing conclusions.
    print(var_mean_time, ln_time)  # originally reported: 3.2 1.18