Convolution operation with L2 normalized weights

Hi all,
Is there a way to L2-normalize the weights of a convolution kernel before performing the convolution?
For a fully connected layer, I’d go about doing something like:

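# __init__()
self.weight = nn.Parameter(torch.randn(in_size, out_size))

# forward()
def forward(self, x):
    # Scale each column (the weights feeding one output unit) to unit L2 norm
    w = F.normalize(self.weight, dim=0)
    out = torch.mm(x, w)
    return out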

But I’m not exactly sure how this would work for convolutions, even if I’m only using 1x1 kernels.

Any ideas?

You could use the same approach, but call F.conv2d instead of torch.mm to perform a convolution.
Also, the input and weight shapes will be different: F.conv2d expects an input of shape [batch_size, in_channels, height, width] and a weight of shape [out_channels, in_channels, kernel_height, kernel_width].
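A minimal sketch of that approach, assuming each output filter (a [in_channels, kernel_height, kernel_width] slice of the weight) should be scaled to unit L2 norm and that kernels are square. The class name NormalizedConv2d, its constructor arguments, and the initialization are hypothetical choices for illustration, not a PyTorch API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NormalizedConv2d(nn.Module):
    """Hypothetical Conv2d whose filters are L2-normalized before every forward pass."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # Square kernels only, to keep the sketch short (assumption).
        self.weight = nn.Parameter(
            0.1 * torch.randn(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.bias = nn.Parameter(torch.zeros(out_channels))
        self.stride = stride
        self.padding = padding

    def normalized_weight(self):
        out_channels = self.weight.shape[0]
        # Flatten each filter to one vector, scale it to unit L2 norm, restore the 4D shape.
        flat = self.weight.reshape(out_channels, -1)
        return F.normalize(flat, dim=1).view_as(self.weight)

    def forward(self, x):
        return F.conv2d(x, self.normalized_weight(), self.bias,
                        stride=self.stride, padding=self.padding)


# Usage: with 1x1 kernels this applies the normalized FC layer at every pixel.
torch.manual_seed(0)
conv = NormalizedConv2d(3, 8, kernel_size=1)
x = torch.randn(2, 3, 5, 5)
y = conv(x)
print(y.shape)  # torch.Size([2, 8, 5, 5])
```

Flattening before normalizing is one way to get a per-filter norm; normalizing over a single tensor dimension of the 4D weight would instead give a different, per-slice normalization.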


I had originally implemented it by dividing the output by the norm of the convolution weights. However, this caused NaN issues, since torch.norm() can return a value that is zero or almost zero.

I’ve reimplemented it following your suggestion, and it seems to work well now. Thanks!
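A small sketch of why the earlier approach could produce NaN while F.normalize does not, assuming the worst case of a filter whose norm is exactly zero: F.normalize divides by max(||w||, eps) instead of by ||w|| directly.

```python
import torch
import torch.nn.functional as F

w = torch.zeros(1, 4)  # a degenerate filter whose L2 norm is exactly zero

# Dividing by the raw norm gives 0 / 0 = NaN.
naive = w / w.norm(dim=1, keepdim=True)

# F.normalize clamps the denominator to at least eps, so the result stays finite.
safe = F.normalize(w, dim=1, eps=1e-12)

print(naive)  # tensor([[nan, nan, nan, nan]])
print(safe)   # tensor([[0., 0., 0., 0.]])
```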

Hi @ptrblck, what if, instead of using nn.Parameter in the __init__() function of the code below, I use nn.utils.weight_norm?
Will that first reparameterize the weights and then apply L2 normalization before each forward pass?

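# __init__()
self.weight = nn.Parameter(torch.randn(in_size, out_size))

# forward()
def forward(self, x):
    # Scale each column (the weights feeding one output unit) to unit L2 norm
    w = F.normalize(self.weight, dim=0)
    out = torch.mm(x, w)
    return out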

Looking for help with PyTorch. Thanks!

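weight_norm is applied to modules, not parameters, so it won’t work in your code snippet.
Internally, it registers new parameters (weight_g for the magnitude and weight_v for the direction) and recomputes weight from them before each forward pass, as seen in its implementation. You could try to adapt this code to your use case and manipulate your parameter manually.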
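A minimal sketch of doing that reparameterization by hand on raw parameters, assuming one norm per output row (which is what nn.utils.weight_norm uses for nn.Linear with its default dim=0). WeightNormLinear, v, g, and compute_weight are hypothetical names for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightNormLinear(nn.Module):
    """Hypothetical linear layer with a hand-rolled weight = g * v / ||v||."""

    def __init__(self, in_features, out_features):
        super().__init__()
        # Direction parameter v: one row per output unit.
        self.v = nn.Parameter(0.1 * torch.randn(out_features, in_features))
        # Magnitude parameter g: start at the current row norms so weight == v initially.
        self.g = nn.Parameter(self.v.detach().norm(dim=1, keepdim=True))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def compute_weight(self):
        # Recomputed on every forward pass from g and v (row-wise L2 norm of v).
        return self.g * self.v / self.v.norm(dim=1, keepdim=True)

    def forward(self, x):
        return F.linear(x, self.compute_weight(), self.bias)


torch.manual_seed(0)
layer = WeightNormLinear(2, 2)
print(torch.allclose(layer.compute_weight(), layer.v))  # True

# Fixing g at 1 gives the same unit-norm rows as F.normalize(weight).
with torch.no_grad():
    layer.g.fill_(1.0)
print(torch.allclose(layer.compute_weight(), F.normalize(layer.v)))  # True
```

If g stays trainable, the layer learns a magnitude per output unit; fixing g at 1 reduces it to plain L2 normalization of the weight rows.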


Hi @ptrblck, thanks for your response. I tried the following two implementations, and Implementation 2 gave me better accuracy.

But I’m not clear on how nn.utils.weight_norm changes the performance. The PyTorch documentation says nn.utils.weight_norm only decouples the magnitude of a weight tensor from its direction. So why do the two versions give numerically different results?

Implementation 1

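import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        # Scale each row of the (out_features, in_features) weight to unit L2 norm
        weight = F.normalize(self.linear.weight)
        out = torch.mm(x, weight.t()) + self.linear.bias
        return out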

Implementation 2

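import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # weight_norm replaces `weight` with `weight_g` (magnitude) and `weight_v` (direction)
        self.linear = nn.utils.weight_norm(nn.Linear(2, 2))

    def forward(self, x):
        # Note: weight_norm's pre-hook recomputes `weight` inside self.linear(x), overwriting this
        self.linear.weight = F.normalize(self.linear.weight)
        out = self.linear(x)
        return out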
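A quick check that may explain the difference, assuming the legacy nn.utils.weight_norm used in this thread (recent PyTorch releases deprecate it in favor of torch.nn.utils.parametrizations.weight_norm, which behaves differently). Its forward pre-hook rebuilds weight from weight_g and weight_v whenever the module is called, so the F.normalize assignment in Implementation 2 appears to be discarded. The magnitude of 3 below is an arbitrary stand-in for a learned value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
lin = nn.utils.weight_norm(nn.Linear(2, 2))

# Pretend training has pushed the learned magnitude of every row to 3 (arbitrary).
with torch.no_grad():
    lin.weight_g.fill_(3.0)

x = torch.randn(4, 2)
lin.weight = F.normalize(lin.weight)  # the extra step in Implementation 2
out = lin(x)  # the forward pre-hook recomputes lin.weight from weight_g and weight_v

# The row norms follow weight_g, not the F.normalize call above.
print(lin.weight.norm(dim=1))  # both entries are ~3.0
```

If this reading is right, Implementation 2 behaves like plain weight normalization with a learnable per-row magnitude, whereas Implementation 1 pins every row norm to 1, which would account for the different results.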