Weights not updating with custom normalization layer

Hi,
I’m trying to implement a custom layer that mimics nn.GroupNorm on PyTorch 1.9.0.

import torch
import torch.nn as nn
from torch import Tensor

class CustomGroupNorm(nn.Module):
    """
    Custom implementation of PyTorch's `torch.nn.GroupNorm` layer.
    """

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
        super(CustomGroupNorm, self).__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.weight = nn.Parameter(torch.ones(1, self.num_channels, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, self.num_channels, 1, 1))
        self.eps = eps
        if self.num_channels % self.num_groups != 0 or self.num_channels < self.num_groups:
            raise ValueError('Number of channels must be divisible by the number of groups')

    def forward(self, x: Tensor):
        N, C, H, W = x.size()
        x = x.view(N, self.num_groups, -1)
        x = nn.functional.layer_norm(input=x, normalized_shape=(x.size(-1), ), weight=None, bias=None, eps=self.eps)
        x = x.view(N, C, H, W)

        return x * self.weight + self.bias
  • Issue: When training with this layer, self.weight and self.bias do not get updated and the overall loss stays the same.
  • Intention: Prior to this, I had successfully trained this layer with a manual mean/var calculation, but I have now replaced only that calculation with nn.functional.layer_norm(weight=None, bias=None) for simplification (a rough sketch of the earlier version is below).
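For reference, the earlier forward pass computed the statistics manually, roughly like this (a simplified sketch, not the exact code I had):

# Simplified sketch of the earlier forward() with a manual mean/var calculation
# (same reshaping as above; not the exact original code).
def forward(self, x):
    N, C, H, W = x.size()
    x = x.view(N, self.num_groups, -1)
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    x = (x - mean) / torch.sqrt(var + self.eps)
    x = x.view(N, C, H, W)
    return x * self.weight + self.bias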

I expected self.weight and self.bias would still be updated via autograd since they are declared as nn.Parameter, but they aren't.
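This is roughly how I am checking the parameters after a backward pass (simplified; norm_layer here is just an instance of the CustomGroupNorm above):

# norm_layer: an instance of CustomGroupNorm; run this after loss.backward().
# Both parameters should show non-None, non-zero gradients if autograd reaches them.
for name, p in norm_layer.named_parameters():
    print(name, p.requires_grad, None if p.grad is None else p.grad.abs().sum().item())

Any help or pointers would be greatly appreciated. Thank you.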

@Ju_Jeremy I replicated the same code with random inputs, and the loss changes and the weight and bias do get updated.
What loss function are you using?

import torch
import torch.nn as nn
import torch.optim as optim

class CustomGroupNorm(nn.Module):
    """
    Custom implementation of PyTorch's `torch.nn.GroupNorm` layer.
    """

    def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5):
        super(CustomGroupNorm, self).__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.weight = nn.Parameter(torch.ones(1, self.num_channels, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, self.num_channels, 1, 1))
        self.eps = eps
        if self.num_channels % self.num_groups != 0 or self.num_channels < self.num_groups:
            raise ValueError('Number of channels must be divisible by the number of groups')

    def forward(self, x):
        print(x.size())
        N, C, H, W = x.size()
        x = x.view(N, self.num_groups, -1)
        x = nn.functional.layer_norm(input=x, normalized_shape=(x.size(-1), ), weight=None, bias=None, eps=self.eps)
        x = x.view(N, C, H, W)
        return x * self.weight + self.bias
    
N = 5
C = 3
H = 28
W = 28
model = CustomGroupNorm(num_groups=C, num_channels=C)
X = torch.rand(N, C, H, W)


loss = torch.nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for cur_epoch in range(10):
    model.zero_grad()
    output = model(X)
    final_loss = loss(output, X)
    final_loss.backward()
    optimizer.step()
    print("epoch {0} Loss {1}".format(cur_epoch, final_loss))
    print(model.bias.data[0][0])
    print(model.weight.data[0][0])
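You could also check directly whether the gradients reach the custom layer's parameters, for example along these lines (reusing model, X, and loss from above):

# Sanity check: one forward/backward pass outside the loop, then inspect the
# gradients of the custom layer's parameters directly.
model.zero_grad()
out = model(X)
loss(out, X).backward()
for name, p in model.named_parameters():
    print(name, "grad sum:", p.grad.abs().sum().item() if p.grad is not None else None)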