Hi, I tried the following two implementations, and with Implementation 2 I am getting better accuracy.

But I am not clear on how `nn.utils.weight_norm` changes the performance. The PyTorch documentation says that `nn.utils.weight_norm` just decouples the weight's magnitude from its direction. Why, then, is there a difference in the numerical results?
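To make the "decoupling" concrete, here is a small sketch of what `nn.utils.weight_norm` does under the hood: it replaces the `weight` parameter with two parameters, `weight_g` (per-row magnitude) and `weight_v` (direction), and recomputes `weight = g * v / ||v||` before each forward pass. The layer sizes below are illustrative.

```
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.utils.weight_norm(nn.Linear(2, 2))

# weight_norm splits `weight` into two trainable parameters:
#   weight_g: per-output-row magnitude, shape (2, 1)
#   weight_v: direction, same shape as the original weight, (2, 2)
print(lin.weight_g.shape)
print(lin.weight_v.shape)

# The effective weight is reconstructed as g * v / ||v|| (norm over dim 1)
w = lin.weight_g * lin.weight_v / lin.weight_v.norm(dim=1, keepdim=True)
assert torch.allclose(w, lin.weight)
```

Note that both `weight_g` and `weight_v` remain trainable, so the magnitude is still learned; it is not fixed to 1 the way `F.normalize` fixes it.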

Implementation 1

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        # L2-normalize each row of the weight matrix, then apply it manually
        weight = F.normalize(self.linear.weight)
        out = torch.mm(x, weight.t()) + self.linear.bias
        return out
```

Implementation 2

```
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.utils.weight_norm(nn.Linear(2, 2))

    def forward(self, x):
        # overwrite the weight (recomputed by weight_norm's hook)
        # with its L2-normalized version
        self.linear.weight = F.normalize(self.linear.weight)
        out = self.linear(x)
        return out
```

Could you please tell me the right way to use L2-normalized weights for classification?

Thanks!