I am attempting to replicate the exact implementation of GroupNorm, but upon testing, I’m off by approximately 1e-2
when randomly generating a tensor and feeding it through my implementation of GroupNorm and pytorch’s GroupNorm function.
I suspect that there is a running average implementation in GroupNorm that accounts for the difference? Here is my code:
(PS: Ignore the name of the class, it is not relevant to the implementation )
# This implementation of GroupNormalization comes from the original paper:
# Figure 3 in https://arxiv.org/pdf/1803.08494.pdf
import torch
import torch.nn as nn
class SubtractiveOnlyGroupNorm(nn.Module):
def __init__(self, num_groups, num_channels, affine=True, eps=1e-5, momentum=1e-4):
super().__init__()
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.affine = affine
self.running_mean = 0
self.running_var = 0
self.momentum = momentum
if self.affine:
self.weight = nn.Parameter(torch.ones(1, num_channels, 1, 1))
self.bias = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
def forward(self, x):
# Reshape the input tensor so that the spatial dimensions and channels are grouped together
# We assume that the input has shape (batch_size, num_channels, height, width)
batch_size, num_channels, height, width = x.size()
x = x.view(batch_size, self.num_groups, num_channels // self.num_groups, height, width)
# Calculate the mean and variance for each group
# mean = x.mean(dim=(2, 3, 4), keepdim=True)
# var = x.var(dim=(2, 3, 4), keepdim=True)
mean = torch.mean(x, dim=(2,3,4), keepdim=True)
var = torch.var(x, dim=(2,3,4), unbiased=True, keepdim=True)
self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean
self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
# Apply normalization
x = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)
# Reshape the normalized tensor back to its original shape
x = x.view(batch_size, num_channels, height, width)
# Apply the learned weight and bias
if self.affine:
x = x * self.weight + self.bias
return x
# main function
if __name__ == "__main__":
x = torch.randn(32, 3, 28, 28) # batch_size, num_channels, height, width
norm1 = nn.GroupNorm(1, 3, affine=False, eps=1e-5)
norm2 = SubtractiveOnlyGroupNorm(1, 3, affine=False, eps=1e-5)
y1 = norm1(x)[3][0][0][1]
y2 = norm2(x)[3][0][0][1]
print(y1, y2)
print((norm1(x) == norm2(x)).all())