Custom LayerNorm vs PyTorch implementation

Hello,
I came across a LayerNorm implementation based on `models/convnext.py` from the facebookresearch/ConvNeXt repository on GitHub (main branch). When I compared it with GroupNorm, the results looked quite strange. Here is a snippet:

import torch
import torch.nn as nn

class LayerNormFunction(torch.autograd.Function):
    """Channels-first layer normalisation with a hand-written backward pass.

    For a 4-D input ``x`` of shape (N, C, H, W), every vector ``x[n, :, h, w]``
    is normalised over its C channels (biased variance), then scaled and
    shifted per channel::

        y   = (x - mean_C(x)) / sqrt(var_C(x) + eps)
        out = weight[c] * y + bias[c]

    Statistics are taken over dim 1 only, at each spatial position
    separately. This is therefore NOT the same operation as
    ``nn.GroupNorm(1, C)`` or ``nn.LayerNorm((C, H, W))``, which normalise
    each sample over C, H and W jointly. The built-in equivalent is
    ``nn.LayerNorm(C)`` applied to ``x.permute(0, 2, 3, 1)``.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        """Normalise ``x`` over channels and apply the per-channel affine.

        Args:
            x: 4-D (N, C, H, W) tensor. The shape unpacking below raises
                ``ValueError`` for any other rank.
            weight: per-channel scale with C elements.
            bias: per-channel shift with C elements.
            eps: constant added to the variance for numerical stability.

        Returns:
            Tensor with the same shape as ``x``.
        """
        ctx.eps = eps
        _, C, _, _ = x.size()
        var, mu = torch.var_mean(x, dim=1, keepdim=True, unbiased=False)
        y = (x - mu) / torch.sqrt(var + eps)
        # The normalised activations and the variance are all that backward
        # needs to rebuild the input gradient, so there is no need to keep x.
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    # y and var were saved without a graph back to x, so differentiating this
    # backward again would silently give wrong results. once_differentiable
    # turns that case into an explicit error instead.
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Return gradients with respect to (x, weight, bias, eps).

        With ``g = grad_output * weight`` and ``sigma = sqrt(var + eps)``, the
        gradient of a normalisation over C is::

            dx = (g - mean_C(g) - y * mean_C(g * y)) / sigma

        The affine parameters receive the usual sums over N, H and W.
        ``eps`` is treated as a constant and gets no gradient.
        """
        y, var, weight = ctx.saved_tensors
        C = grad_output.size(1)
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)
        mean_gy = (g * y).mean(dim=1, keepdim=True)
        grad_x = (g - mean_g - y * mean_gy) / torch.sqrt(var + ctx.eps)
        grad_weight = (grad_output * y).sum(dim=(0, 2, 3))
        grad_bias = grad_output.sum(dim=(0, 2, 3))
        return grad_x, grad_weight, grad_bias, None

class LayerNorm2d(nn.Module):
    """Layer norm over the channel dimension of (N, C, H, W) inputs.

    Each spatial position is normalised across its ``channels`` values, and
    a learnable per-channel scale (``weight``) and shift (``bias``) are then
    applied. Because the statistics are per position, results differ from
    ``nn.GroupNorm(1, C)`` / ``nn.LayerNorm((C, H, W))``, which pool over the
    whole (C, H, W) volume.

    Args:
        channels: number of channels C; sets the size of ``weight``/``bias``.
        eps: constant added to the variance for numerical stability.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        # Assigning nn.Parameter attributes registers them in the same order
        # (weight, then bias) as explicit register_parameter calls would.
        # They start out as the identity affine transform.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` over channels via the custom autograd function."""
        normalized = LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
        return normalized

# Compare LayerNorm2d against PyTorch's built-in normalisation layers.
#
# GroupNorm(1, C) and LayerNorm((C, H, W)) both normalise each sample over
# C, H and W jointly, so they agree with each other. LayerNorm2d normalises
# over the channel dimension only, separately at every (h, w) position, so it
# is EXPECTED to differ from both. Its built-in counterpart is LayerNorm(C)
# applied to a channels-last view of the input, which is checked last.
torch.manual_seed(0)
x = torch.randn(1, 3, 5, 5)
g = nn.GroupNorm(1, 3)
l = nn.LayerNorm((3, 5, 5))
lc = LayerNorm2d(3)
l_ch = nn.LayerNorm(3)  # normalises over the last dimension only (size C)

y_g = g(x)
y_l = l(x)
y_lc = lc(x)
# Move channels last, normalise over C, then move channels back to dim 1.
y_l_ch = l_ch(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
print(f"(y_g - y_lc).pow(2).sum().sqrt() = {(y_g - y_lc).pow(2).sum().sqrt()}")
print(f"(y_g - y_l).pow(2).sum().sqrt() = {(y_g - y_l).pow(2).sum().sqrt()}")
print(f"(y_l - y_lc).pow(2).sum().sqrt() = {(y_l - y_lc).pow(2).sum().sqrt()}")
print(f"(y_lc - y_l_ch).pow(2).sum().sqrt() = {(y_lc - y_l_ch).pow(2).sum().sqrt()}")

Can someone explain why the results of LayerNorm2d differ from those of PyTorch's built-in norms?

1 Like