How is weight normalization calculated?
>>> import torch, torch.nn as nn
>>> lin = nn.Linear(3, 3, bias=False)
>>> inp = torch.randn(3, 3)
>>> lin = nn.utils.weight_norm(lin)
>>> optimizer = torch.optim.SGD(lin.parameters(), lr=0.01)
>>> list(lin.parameters())
[Parameter containing:
tensor([[0.6400],
        [0.6961],
        [0.6579]], requires_grad=True), Parameter containing:
tensor([[ 0.2672,  0.5720,  0.1048],
        [-0.4538,  0.5243, -0.0612],
        [-0.5537,  0.0813, -0.3459]], requires_grad=True)]
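(For context on the two parameters printed above: `weight_norm` replaces `weight` with a magnitude `weight_g` of shape `[3, 1]` and a direction `weight_v` of shape `[3, 3]`, and at initialization `weight_g` is set to the per-row norm of `weight_v` — e.g. the norm of the first row above is sqrt(0.2672² + 0.5720² + 0.1048²) ≈ 0.6400. A quick check, using a fresh seeded module since the snippet above does not set a seed:)

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.utils.weight_norm(nn.Linear(3, 3, bias=False))

# weight_g is one scalar per output row, initialized to that row's norm of weight_v
print(lin.weight_g.shape)  # torch.Size([3, 1])
print(torch.allclose(lin.weight_g, lin.weight_v.norm(dim=1, keepdim=True)))  # True
```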
>>> optimizer.zero_grad()
>>> loss = lin(inp).sum()
>>> loss.backward()
>>> optimizer.step()
>>> lin.weight_g
Parameter containing:
tensor([[0.6161],
        [0.6924],
        [0.6708]], requires_grad=True)
>>> lin.weight_v
Parameter containing:
tensor([[ 0.2619,  0.5746,  0.1042],
        [-0.4715,  0.5084, -0.0660],
        [-0.5581,  0.0611, -0.3437]], requires_grad=True)
>>> lin.weight
tensor([[ 0.2672,  0.5720,  0.1048],
        [-0.4538,  0.5243, -0.0612],
        [-0.5537,  0.0813, -0.3459]], grad_fn=<...>)
I use this formula to compute lin.weight from lin.weight_g and lin.weight_v, but it gives the wrong lin.weight:

lin.weight = (lin.weight_g / lin.weight_v.norm()) * lin.weight_v
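(The likely culprit: `lin.weight_v.norm()` with no arguments is the norm of the whole matrix, a single scalar. `weight_norm` on a Linear layer defaults to `dim=0`, so the norm is taken separately for each output row, matching `weight_g`'s `[3, 1]` shape. A sketch of the per-row reconstruction, again on a fresh seeded module:)

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.utils.weight_norm(nn.Linear(3, 3, bias=False))
lin(torch.randn(3, 3))  # run a forward so lin.weight is up to date

# One norm per output row (shape [3, 1]), broadcast against weight_v ([3, 3])
recon = lin.weight_g * lin.weight_v / lin.weight_v.norm(dim=1, keepdim=True)
print(torch.allclose(recon, lin.weight))  # True
```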