Layer Normalization gradients

I am attempting to write my own custom Layer Normalization layer, and I intend for my implementation to behave identically to PyTorch's nn.LayerNorm. However, whenever I compare the two, I get slightly different gradients for the input. This is the script I am using:
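
As I understand it, both layers should compute the same thing over the last dimension (the one covered by normalized_shape=(5,)):

y = (x - mean(x)) / sqrt(var(x) + eps) * gamma + beta

where the mean and the biased variance are taken over that last dimension, and gamma and beta are the learnable scale and shift.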

import torch
import torch.nn as nn

class CustomLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super(CustomLayerNorm, self).__init__()
        # Define learnable parameters gamma (scale) and beta (shift)
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps

    def forward(self, x):
        # Calculate mean and (biased) variance over the last dimension,
        # i.e. the dimension covered by normalized_shape
        mean = x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, unbiased=False, keepdim=True)
        
        # Normalize the input
        x_normalized = (x - mean) / torch.sqrt(variance + self.eps)
        
        # Scale and shift
        out = self.gamma * x_normalized + self.beta
        return out

# Initialize custom layer norm and PyTorch's layer norm
custom_layer_norm = CustomLayerNorm(normalized_shape=(5,))
pytorch_layer_norm = nn.LayerNorm(normalized_shape=(5,))

# Set PyTorch's gamma and beta to match our custom layer norm
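# (nn.LayerNorm initializes its weight to ones and its bias to zeros by default,
# so they should already match the custom layer's gamma/beta; copying them
# explicitly just makes sure of it.)
with torch.no_grad():
    pytorch_layer_norm.weight.copy_(custom_layer_norm.gamma)
    pytorch_layer_norm.bias.copy_(custom_layer_norm.beta)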

# Define the input tensor
x = torch.tensor(
    [[[76.1738, 77.1738, 76.1738, 77.1738, 76.1738],
      [77.0152, 76.7141, 76.1989, 77.1735, 76.1744],
      [77.0831, 75.7576, 76.2240, 77.1725, 76.1750],
      [76.3149, 75.1838, 76.2491, 77.1709, 76.1757],
      [75.4170, 75.5201, 76.2741, 77.1687, 76.1763]]], requires_grad=True
)

# Forward pass for both layers
custom_output = custom_layer_norm(x)
pytorch_output = pytorch_layer_norm(x)

# Compare the output of the custom layer to the PyTorch layer
print("Custom LayerNorm Output:\n", custom_output)
print("PyTorch LayerNorm Output:\n", pytorch_output)

# Backward pass
custom_output.sum().backward()
print("\nInput Gradients - Custom LayerNorm:\n", x.grad)

pytorch_output.sum().backward()

# Compare gradients
print("\nInput Gradients - PyTorch LayerNorm:\n", x.grad)

Am I doing this wrong? Why don’t the input gradients match?