Weight norm calculation

How is weight normalization calculated?

import torch, torch.nn as nn
lin = nn.Linear(3, 3, bias=False)
inp = torch.randn(3, 3)
lin = nn.utils.weight_norm(lin)
optimizer = torch.optim.SGD(lin.parameters(), lr=0.01)
list(lin.parameters())

[Parameter containing:
tensor([[0.6400],
[0.6961],
[0.6579]], requires_grad=True), Parameter containing:
tensor([[ 0.2672, 0.5720, 0.1048],
[-0.4538, 0.5243, -0.0612],
[-0.5537, 0.0813, -0.3459]], requires_grad=True)]

optimizer.zero_grad()
loss = lin(inp).sum()
loss.backward()
optimizer.step()
lin.weight_g

Parameter containing:
tensor([[0.6161],
[0.6924],
[0.6708]], requires_grad=True)

lin.weight_v

Parameter containing:
tensor([[ 0.2619, 0.5746, 0.1042],
[-0.4715, 0.5084, -0.0660],
[-0.5581, 0.0611, -0.3437]], requires_grad=True)

lin.weight

tensor([[ 0.2672, 0.5720, 0.1048],
[-0.4538, 0.5243, -0.0612],
[-0.5537, 0.0813, -0.3459]], grad_fn=)

I use this formula to calculate lin.weight from lin.weight_g and lin.weight_v, but it gives the wrong lin.weight:

lin.weight = (lin.weight_g/lin.weight_v.norm())*lin.weight_v

If you check the docs, the default for dim is 0. Maybe you want to set it to None here?
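To illustrate the suggestion above, here is a minimal sketch that assumes torch.nn.utils.weight_norm (as used in the question) is available in the installed PyTorch version. With dim=None the norm is taken over the whole tensor, so the formula from the question should reproduce lin.weight. The in-place perturbation of weight_g and weight_v is only there so the check is not trivially true at initialization.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# dim=None: a single norm over the entire weight tensor
lin = nn.utils.weight_norm(nn.Linear(3, 3, bias=False), dim=None)

# Perturb g and v so that weight != v, then run a forward pass
# so the pre-forward hook recomputes lin.weight.
with torch.no_grad():
    lin.weight_g.mul_(2.0)
    lin.weight_v.add_(0.1)
lin(torch.randn(2, 3))

g, v = lin.weight_g, lin.weight_v
whole_tensor = (g / v.norm()) * v  # formula from the question
g_numel = g.numel()                # weight_g holds a single value here
formula_matches = torch.allclose(lin.weight, whole_tensor, atol=1e-6)
print(g_numel, formula_matches)
```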


Yes, it works with dim=None in weight_norm. For the default dim=0,
I used this formula:

lin.weight_g*(lin.weight_v/lin.weight_v.norm(dim=1, keepdim=True))

This also works:

lin.weight_g*(lin.weight_v/torch.norm_except_dim(lin.weight_v, 2, dim=0))
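As a quick check of the two formulas above for the default dim=0, the sketch below compares them against lin.weight and also shows that the whole-tensor formula from the question does not match. It assumes torch.nn.utils.weight_norm as used in this thread; the perturbation values are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.utils.weight_norm(nn.Linear(3, 3, bias=False))  # default dim=0

# Move g and v away from their initial values, then run a forward
# pass so lin.weight is recomputed from them.
with torch.no_grad():
    lin.weight_g.mul_(2.0)
    lin.weight_v.add_(0.1)
lin(torch.randn(2, 3))

g, v = lin.weight_g, lin.weight_v
per_row = g * (v / v.norm(dim=1, keepdim=True))            # one norm per output row
via_helper = g * (v / torch.norm_except_dim(v, 2, dim=0))  # same norms, via the helper
whole_tensor = (g / v.norm()) * v                          # single norm over all entries

per_row_ok = torch.allclose(lin.weight, per_row, atol=1e-6)
helper_ok = torch.allclose(lin.weight, via_helper, atol=1e-6)
whole_ok = torch.allclose(lin.weight, whole_tensor, atol=1e-6)
print(per_row_ok, helper_ok, whole_ok)
```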

I found that with gradient descent here, optimizer.step() updates weight_g and weight_v, while lin.weight stays as it was. (So maybe the last training batch has no impact on lin.weight until another forward pass runs?)
lin.weight is recomputed from the updated weight_g and weight_v before each forward pass.
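The behaviour described above can be observed with a small sketch like the following, again assuming torch.nn.utils.weight_norm as in the question. It records lin.weight right after optimizer.step(), compares it with the value computed from the updated weight_g and weight_v, and then repeats the comparison after one more forward call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.utils.weight_norm(nn.Linear(3, 3, bias=False))
optimizer = torch.optim.SGD(lin.parameters(), lr=0.01)
inp = torch.randn(3, 3)

optimizer.zero_grad()
lin(inp).sum().backward()
optimizer.step()  # updates weight_g and weight_v only

# What the weight should be, given the updated g and v.
g, v = lin.weight_g.detach(), lin.weight_v.detach()
expected = g * (v / v.norm(dim=1, keepdim=True))

# lin.weight still holds the value from the last forward pass.
stale_differs = not torch.allclose(lin.weight.detach(), expected, atol=1e-6)

lin(inp)  # the next forward call recomputes lin.weight first
fresh_matches = torch.allclose(lin.weight.detach(), expected, atol=1e-6)
print(stale_differs, fresh_matches)
```

If the up-to-date weight is needed after the final step, computing it from weight_g and weight_v as above (or running one more forward pass) avoids reading the stale attribute.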

Also, the first lin.weight_g is calculated as:

torch.norm_except_dim(lin.weight, 2, dim=0)

This returns a tensor that keeps the size of the specified dimension and reduces every other dimension to size 1.
Also, the first lin.weight_v is the same as the initial lin.weight.
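A small sketch to confirm the initial values described above, assuming torch.nn.utils.weight_norm as in the question. The name w0 is just a local copy of the weight taken before weight norm is applied.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
base = nn.Linear(3, 3, bias=False)
w0 = base.weight.detach().clone()  # weight before weight norm is applied

lin = nn.utils.weight_norm(base)   # default dim=0

g_matches = torch.allclose(lin.weight_g.detach(), torch.norm_except_dim(w0, 2, dim=0))
v_matches = torch.allclose(lin.weight_v.detach(), w0)
g_shape = tuple(lin.weight_g.shape)  # one norm per output row
print(g_matches, v_matches, g_shape)
```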

I think the main thing here is:

  • torch.norm(dim=1) reduces only the specified dimension: that dimension is dropped from the output shape, or kept with size 1 if keepdim=True.

  • torch.norm_except_dim(dim=1) does the opposite: the output keeps the size of the specified dimension, and all other dimensions become size 1.

  • We want lin.weight_g and the norm of lin.weight_v to have the same shape, so we use lin.weight_v.norm(dim=1, keepdim=True) or torch.norm_except_dim(lin.weight_v, 2, dim=0).

  • If I set dim=1 in weight_norm, then lin.weight_g is calculated by torch.norm_except_dim(lin.weight, 2, dim=1), and lin.weight uses lin.weight_v.norm(dim=0, keepdim=True) or torch.norm_except_dim(lin.weight_v, 2, dim=1) so that the shape matches lin.weight_g.
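The shape rules in the list above are easier to see with a non-square weight. The sketch below uses a 4x3 tensor (an arbitrary choice for illustration) and, for the dim=1 case, assumes torch.nn.utils.weight_norm as used in the thread.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
w = torch.randn(4, 3)  # same layout as nn.Linear(3, 4).weight

shape_norm = tuple(w.norm(dim=1).shape)                        # dim 1 dropped
shape_norm_keep = tuple(w.norm(dim=1, keepdim=True).shape)     # dim 1 kept with size 1
shape_except0 = tuple(torch.norm_except_dim(w, 2, dim=0).shape)  # dim 0 kept, rest size 1
shape_except1 = tuple(torch.norm_except_dim(w, 2, dim=1).shape)  # dim 1 kept, rest size 1

# With dim=1 in weight_norm, weight_g has one entry per input column.
lin = nn.utils.weight_norm(nn.Linear(3, 4, bias=False), dim=1)
g_shape_dim1 = tuple(lin.weight_g.shape)
v = lin.weight_v
dim1_matches = torch.allclose(
    lin.weight, lin.weight_g * (v / v.norm(dim=0, keepdim=True)), atol=1e-6
)
print(shape_norm, shape_norm_keep, shape_except0, shape_except1, g_shape_dim1)
```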

The plain weight tensor can be recovered by removing weight norm from the current weight-normed layer, using the function torch.nn.utils.remove_weight_norm(layer).

Run the code snippet below and compare the printed values:

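import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm

l = nn.Linear(2, 4)
print(l.state_dict())      # weight and bias
w1 = weight_norm(l, name="weight")
print(w1.state_dict())     # bias, weight_g, weight_v
remove_weight_norm(w1)
print(w1.state_dict())     # bias and a plain weight again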
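As a follow-up check on the snippet above, this sketch verifies that the weight restored by remove_weight_norm equals the value computed from weight_g and weight_v just before removal. It assumes torch.nn.utils.weight_norm and remove_weight_norm as used in the thread; the scaling of weight_g stands in for training updates.

```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm

torch.manual_seed(0)
layer = weight_norm(nn.Linear(2, 4), name="weight")

# Stand-in for training: change g so the weight differs from its initial value.
with torch.no_grad():
    layer.weight_g.mul_(1.5)

keys_with_norm = sorted(layer.state_dict().keys())
g, v = layer.weight_g.detach(), layer.weight_v.detach()
expected = (g * (v / v.norm(dim=1, keepdim=True))).clone()

remove_weight_norm(layer)  # folds g and v back into a single weight parameter
keys_after = sorted(layer.state_dict().keys())
restored = torch.allclose(layer.weight.detach(), expected, atol=1e-6)
print(keys_with_norm, keys_after, restored)
```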