Reproducing GroupNorm (Running Mean & Var)

I am attempting to replicate the exact behavior of GroupNorm, but when I generate a random tensor and feed it through both my implementation and PyTorch's nn.GroupNorm, the outputs differ by approximately 1e-2.

I suspect that GroupNorm keeps a running average internally and that this accounts for the difference. Here is my code:

(PS: ignore the name of the class; it is not relevant to the implementation.)

# This implementation of GroupNormalization comes from the original paper:
# Figure 3 in https://arxiv.org/pdf/1803.08494.pdf

import torch
import torch.nn as nn

class SubtractiveOnlyGroupNorm(nn.Module):
    def __init__(self, num_groups, num_channels, affine=True, eps=1e-5, momentum=1e-4):
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        self.running_mean = 0
        self.running_var = 0
        self.momentum = momentum
        if self.affine:
            self.weight = nn.Parameter(torch.ones(1, num_channels, 1, 1))
            self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))

    def forward(self, x):
        # Reshape the input tensor so that the spatial dimensions and channels are grouped together
        # We assume that the input has shape (batch_size, num_channels, height, width)
        batch_size, num_channels, height, width = x.size()
        x = x.view(batch_size, self.num_groups, num_channels // self.num_groups, height, width)
        
        # Calculate the mean and variance for each group
        # mean = x.mean(dim=(2, 3, 4), keepdim=True)
        # var = x.var(dim=(2, 3, 4), keepdim=True)
        mean = torch.mean(x, dim=(2,3,4), keepdim=True)
        var = torch.var(x, dim=(2,3,4), unbiased=True, keepdim=True)

        self.running_mean =  self.momentum * self.running_mean + (1 - self.momentum) * mean
        self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var

        # Apply normalization
        x = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        
        # Reshape the normalized tensor back to its original shape
        x = x.view(batch_size, num_channels, height, width)
        
        # Apply the learned weight and bias
        if self.affine:
            x = x * self.weight + self.bias
        
        return x


# main function
if __name__ == "__main__":
    x = torch.randn(32, 3, 28, 28) # batch_size, num_channels, height, width
    norm1 = nn.GroupNorm(1, 3, affine=False, eps=1e-5)
    norm2 = SubtractiveOnlyGroupNorm(1, 3, affine=False, eps=1e-5)
    y1 = norm1(x)[3][0][0][1]
    y2 = norm2(x)[3][0][0][1]
    print(y1, y2)
    print((norm1(x) == norm2(x)).all())
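
As a side note (an editorial sketch, not part of the original post): exact equality with == is a brittle way to compare floating-point tensors, since even a correct reimplementation will usually differ in the last few bits. Comparing the maximum absolute difference, or using torch.allclose with an explicit tolerance, gives a clearer picture of how far apart the two outputs really are. This assumes the SubtractiveOnlyGroupNorm class defined above:

import torch
import torch.nn as nn

x = torch.randn(32, 3, 28, 28)
norm1 = nn.GroupNorm(1, 3, affine=False, eps=1e-5)
norm2 = SubtractiveOnlyGroupNorm(1, 3, affine=False, eps=1e-5)
y1, y2 = norm1(x), norm2(x)
print((y1 - y2).abs().max())               # magnitude of the discrepancy
print(torch.allclose(y1, y2, atol=1e-5))   # tolerant element-wise comparison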

From the docs:

The standard-deviation is calculated via the biased estimator, equivalent to torch.var(input, unbiased=False).

This layer uses statistics computed from input data in both training and evaluation modes.
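
To make the first point concrete: the unbiased estimator divides by N - 1 instead of N, so for a group containing N elements the two variance estimates differ by a factor of N / (N - 1). A minimal illustration of this relationship (an editorial sketch, not from the original thread):

import torch

g = torch.randn(3 * 28 * 28)              # one group's worth of elements
n = g.numel()
v_biased = torch.var(g, unbiased=False)   # what nn.GroupNorm uses
v_unbiased = torch.var(g, unbiased=True)  # what the first implementation used
print(torch.allclose(v_unbiased, v_biased * n / (n - 1)))  # True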

So the two fixes are to compute the variance with unbiased=False and to normalize with the current batch's statistics rather than running averages. After making these changes, the error matches the expected precision limit of float32:

class SubtractiveOnlyGroupNorm(nn.Module):
    def __init__(self, num_groups, num_channels, affine=True, eps=1e-5):
        super().__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine
        # Note: no running statistics; GroupNorm normalizes with per-batch statistics only.
        if self.affine:
            self.weight = nn.Parameter(torch.ones(1, num_channels, 1, 1))
            self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))

    def forward(self, x):
        # Reshape the input tensor so that the spatial dimensions and channels are grouped together
        # We assume that the input has shape (batch_size, num_channels, height, width)
        batch_size, num_channels, height, width = x.size()
        x = x.view(batch_size, self.num_groups, num_channels // self.num_groups, height, width)

        mean = torch.mean(x, dim=(2,3,4), keepdim=True)
        var = torch.var(x, dim=(2,3,4), unbiased=False, keepdim=True)
        # Apply normalization
        x = (x - mean) / torch.sqrt(var + self.eps)
        
        # Reshape the normalized tensor back to its original shape
        x = x.view(batch_size, num_channels, height, width)
        
        # Apply the learned weight and bias
        if self.affine:
            x = x * self.weight + self.bias
        
        return x


# main function
if __name__ == "__main__":
    x = torch.randn(32, 3, 28, 28) # batch_size, num_channels, height, width
    norm1 = nn.GroupNorm(1, 3, affine=False, eps=1e-5)
    norm2 = SubtractiveOnlyGroupNorm(1, 3, affine=False, eps=1e-5)
    y1 = norm1(x)[3][0][0][1]
    y2 = norm2(x)[3][0][0][1]
    print(y1, y2)
    # tensor(1.6263) tensor(1.6263)
    print((norm1(x) - norm2(x)).abs().max())
    # tensor(4.7684e-07)
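
For completeness, here is a quick extra check with more than one group, since the test above only exercises num_groups=1. The shapes are arbitrary (an illustrative configuration, not from the original post), and it assumes the fixed SubtractiveOnlyGroupNorm class above:

import torch
import torch.nn as nn

x = torch.randn(8, 6, 14, 14)                    # 6 channels split into 2 groups of 3
norm1 = nn.GroupNorm(2, 6, affine=False, eps=1e-5)
norm2 = SubtractiveOnlyGroupNorm(2, 6, affine=False, eps=1e-5)
print((norm1(x) - norm2(x)).abs().max())         # should again be at float32 precision (~1e-7)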

Thank you so much; it sounds like that did the trick!

I have one more question:

  1. Can you confirm that there is no running mean and variance computation in the PyTorch GroupNorm implementation?

Yes, that is how I understand the docs:

This layer uses statistics computed from input data in both training and evaluation modes.
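
One quick way to convince yourself (an editorial sketch, not from the original thread): nn.GroupNorm registers no buffers at all, so there is no running_mean or running_var to track, and its output is identical in train() and eval() mode, unlike batch normalization:

import torch
import torch.nn as nn

norm = nn.GroupNorm(1, 3, affine=False)
print(list(norm.buffers()))           # [] -> no running_mean / running_var buffers

x = torch.randn(4, 3, 8, 8)
y_train = norm.train()(x)
y_eval = norm.eval()(x)
print(torch.equal(y_train, y_eval))   # True: same per-batch statistics in both modes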

In case anyone reading this thread is interested in the GroupNorm code, I made a small GitHub repo.