How to do weight normalization in the last classification layer?

Hi @ptrblck, Thanks for your quick response!

The following code snippet reproduces the error:

import torch
import torch.nn as nn
import torch.optim as optim


class LinearLayer(nn.Module):
    def __init__(self):
        super(LinearLayer, self).__init__()
        self.classifier = nn.Sequential(nn.Linear(10, 6))

    def forward(self, x):
        with torch.no_grad():                     ### 1: normalize the weight rows in place before the forward pass
            self.classifier[0].weight.div_(torch.norm(self.classifier[0].weight, dim=1, keepdim=True))
        out = self.classifier(x)
        with torch.no_grad():                     ### 2: normalize the weight rows in place again after the forward pass
            self.classifier[0].weight.div_(torch.norm(self.classifier[0].weight, dim=1, keepdim=True))
        return out

linear = LinearLayer()
optimizer = optim.SGD([{'params': linear.classifier.parameters(), 'lr': 0.1},], 
                      momentum=0.9, weight_decay=5e-4)


for i in range(10):
    x = torch.randn(4, 10)
    targets = torch.LongTensor([1, 2, 3, 4])
    optimizer.zero_grad()
    out = linear(x)
    print(out.size())
    print(targets.size())
    loss = nn.CrossEntropyLoss()(out, targets)
    loss.backward()    # the RuntimeError below is raised here on the first iteration
    optimizer.step()    

Running this code produces the following error:

Error: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [10, 6]], which is output 0 of TBackward, is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
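For what it's worth, I believe the anomaly detection mentioned in the hint can be enabled like this (just a small sketch based on that hint); as far as I understand, it makes backward() also print the traceback of the forward operation whose gradient computation failed:

import torch

# enable anomaly detection so that backward() additionally reports the
# forward operation whose gradient computation raised the error
torch.autograd.set_detect_anomaly(True)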

For my use case, I need to normalize the weights of the classification layer both before (### 1) and after (### 2) the forward pass, and I then need to do a few more operations on the weights after ### 2.

The code works fine if I remove the second normalization, i.e. ### 2.

Before I came across this post, I had posted a few alternatives for this normalization here. Could you please verify whether Implementation 1 in that post is correct?
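In case it makes my intent clearer, here is a minimal sketch of the kind of functional alternative I have in mind (my own sketch, not necessarily the same as the implementations in that post; NormalizedLinear is just a placeholder name). Instead of dividing the parameter in place, it normalizes the weight with F.normalize inside forward, so the parameter itself is never modified and the version autograd saved stays valid:

import torch
import torch.nn as nn
import torch.nn.functional as F


class NormalizedLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(NormalizedLinear, self).__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        # F.normalize returns a new tensor with unit-norm rows, so the
        # parameter is not modified in place and the gradient flows
        # through the normalization itself
        w = F.normalize(self.weight, p=2, dim=1)
        return F.linear(x, w)

Alternatively, I suppose the second in-place normalization (### 2) could simply be moved to after optimizer.step() in the training loop, so the weight is only touched once backward no longer needs its saved version, but I am not sure that is equivalent for my use case.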