How to Properly Normalize Weights During Training in PyTorch Without Bypassing Autograd?

I’m implementing a neural network in PyTorch and need to normalize the weights of certain layers during the forward pass. Specifically, I want to normalize each row of those layers’ weight matrices by its L2 norm. Here’s a simplified version of my code:

import torch
import torch.nn.functional as F

class MyModel(torch.nn.Module):
    def __init__(self, layers, activation_function):
        super(MyModel, self).__init__()
        self.layers = torch.nn.ModuleList(layers)
        self.act_fun = activation_function

    def forward(self, X):
        output = X
        for i, layer in enumerate(self.layers):
            if i > 0:
                # Normalize the weights by overwriting .data (the line I'm unsure about)
                layer.weight.data = F.normalize(layer.weight, p=2, dim=1)
            if i < len(self.layers) - 1:
                output = self.act_fun(layer(output))
            else:
                output = layer(output)
        return output.squeeze()
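
For context, I build and call the model along these lines (the layer sizes and activation here are arbitrary, just to make the snippet concrete):

layers = [torch.nn.Linear(8, 32), torch.nn.Linear(32, 32), torch.nn.Linear(32, 1)]
model = MyModel(layers, torch.nn.functional.relu)
output = model(torch.randn(4, 8))  # every layer after the first has its weights normalized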

My concerns are:

  1. Autograd Compatibility: By directly modifying layer.weight.data, am I bypassing PyTorch’s autograd system? Will this prevent gradients from being computed correctly during backpropagation?
  2. Proper Gradient Updates: Will the weight normalization be accounted for when I call loss.backward(), or do I need to handle this differently to ensure correct gradient computation?
  3. Better Practices: Is there a recommended way to normalize layer weights during training in PyTorch that maintains compatibility with autograd and ensures proper gradient updates?

I’ve read that modifying .data directly can cause issues with gradient tracking, but I’m unsure how to implement weight normalization correctly in this context.

You could implement a parametrization as described in the Parametrizations tutorial (https://pytorch.org/tutorials/intermediate/parametrizations.html): you register a small module on the layer, and its forward is applied transparently whenever layer.weight is accessed, so autograd tracks the normalization.
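
A minimal sketch, assuming your layers are nn.Linear (so dim=1 normalizes each row, as in your code) and that model is an instance of your MyModel; the class name L2RowNormalize is just for illustration:

import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import parametrize

class L2RowNormalize(nn.Module):
    # Called every time `weight` is accessed, so the normalization
    # stays inside the autograd graph and gradients flow through it.
    def forward(self, W):
        return F.normalize(W, p=2, dim=1)

# Normalize every layer except the first, matching your `i > 0` check
for layer in model.layers[1:]:
    parametrize.register_parametrization(layer, "weight", L2RowNormalize())

The optimizer keeps updating the unconstrained tensor, which now lives at layer.parametrizations.weight.original, while layer.weight is recomputed through the parametrization on every access; you can then drop the .data assignment from forward entirely.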

Thanks!

But wouldn’t using a parametrization with a simple forward pass that returns F.normalize(self.weight) be the same as using a model forward pass that includes

normalized_weight = F.normalize(layer.weight, p=2, dim=1)
output = F.linear(output, normalized_weight, layer.bias)

?