How to do weight normalization in last classification layer?

This code should work:

        # last classification layer; its weight has shape [10, 2]
        self.pred = torch.nn.Linear(2, 10, bias=False)

        # renormalize each class's weight vector (row) to unit L2 norm
        with torch.no_grad():
            self.pred.weight.div_(torch.norm(self.pred.weight, dim=1, keepdim=True))
        ...
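For reference, here is a self-contained check you can run (standalone usage is my assumption; note that torch.nn.Linear(2, 10) stores its weight with shape [10, 2], so dim=1 normalizes each row, i.e. each class's weight vector):

    import torch

    pred = torch.nn.Linear(2, 10, bias=False)
    with torch.no_grad():
        pred.weight.div_(torch.norm(pred.weight, dim=1, keepdim=True))

    # every row (class weight vector) now has unit L2 norm
    print(torch.norm(pred.weight, dim=1))  # prints ten values of ~1.0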
  1. You have to flatten the activation somehow; .view would be the easiest way.
    Alternatively, you could write a Flatten module, initialize it in your model's __init__, and call it in your forward (see the sketch after this list). I'm not sure if I understood your question correctly, so let me know if I missed something.
  2. You can access it the same way you already do when normalizing the weights: print(self.pred.weight).
  3. You could normalize the activation in the forward method, similar to your weight norm code (also shown in the sketch below).
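A minimal sketch putting points 1 and 3 together (the feature size of 2 and the 10 classes come from your snippet; the model itself and the custom Flatten module are assumptions, so adjust to your architecture):

    import torch
    import torch.nn as nn

    class Flatten(nn.Module):
        # point 1: flatten everything except the batch dimension
        def forward(self, x):
            return x.view(x.size(0), -1)

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.flatten = Flatten()
            self.pred = nn.Linear(2, 10, bias=False)
            # normalize the class weight vectors, as in your snippet
            with torch.no_grad():
                self.pred.weight.div_(torch.norm(self.pred.weight, dim=1, keepdim=True))

        def forward(self, x):
            x = self.flatten(x)
            # point 3: normalize the activation, mirroring the weight norm
            x = x / torch.norm(x, dim=1, keepdim=True)
            return self.pred(x)

    model = Net()
    out = model(torch.randn(4, 2))  # batch of 4 samples, feature size 2
    print(out.shape)  # torch.Size([4, 10])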