How to scale weights during training?

I’d like to train a convnet where each layer’s weights are divided by the largest absolute weight in that layer at the start of every forward pass, so the range of the weights would always be [-1, 1].

I tried doing it like this:


import torch
import torch.nn as nn
import torch.nn.functional as F


class TestConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size=5, bias=False):
        super(TestConv2d, self).__init__(in_channels, out_channels, kernel_size, bias=bias)

    def forward(self, inputs):
        return F.conv2d(inputs, self.weight / torch.max(torch.abs(self.weight)), self.bias)


class TestLinear(nn.Linear):
    def __init__(self, in_features, out_features, bias=False):
        super(TestLinear, self).__init__(in_features, out_features, bias=bias)

    def forward(self, inputs):
        return F.linear(inputs, self.weight / torch.max(torch.abs(self.weight)), self.bias)


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = TestConv2d(3, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = TestConv2d(32, 64, 5)
        self.flatten = nn.Flatten(start_dim=1)
        self.fc1 = TestLinear(64 * 5 * 5, 390)
        self.fc2 = TestLinear(390, 10)  # in_features must match the 390 outputs of fc1

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

The rest of the code is a standard training loop, and it works as expected on CIFAR-10 without weight scaling.

With weight scaling, however, there is some weird behavior. I tried training it with SGD and AdamW, and AdamW works much better. I had to reduce the initial learning rate significantly (compared to the no-scaling scenario), especially with SGD. However, I’m still not able to reach the same accuracy as without weight scaling (~2% drop with AdamW, ~10% drop with SGD).

Questions:

  1. What is happening to the weight gradients when I’m modifying weights like this?
  2. What would be a correct way to implement weight scaling?
  3. Why does AdamW work so much better than SGD in this case?

First of all, it is very interesting to scale the weights of each layer like this. I tried to understand how the model behaves with it, but I could not figure it out.

I have some thoughts on your questions:
1- The computational graph of your model is in the following picture. It shows how the gradients are applied.

2- I have no idea.

3- Did you use L2 regularization with SGD? AdamW has weight decay built in (applied directly to the weights, decoupled from the gradient, rather than as an L2 term in the loss).

Thank you!

  1. How did you make this diagram?

  2. From the graph, it seems like the weights receive two sets of gradients: one normal, and one that goes through max/abs ops, so I’m guessing it gets applied only to the largest weight? Very interesting indeed.

  3. I did use L2 regularization for both SGD and AdamW. I tried different values, and it doesn’t seem to be very sensitive to it, but some values are better than others.
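For reference, the two gradient paths can be written out by hand. This is only a sketch under my own notation (not from the diagram): w~ = w / m are the scaled weights, m = max_j |w_j| is attained at a single index k, and g_i is the gradient of the loss with respect to the scaled weight w~_i.

```latex
\tilde{w}_j = \frac{w_j}{m}, \qquad m = \max_j |w_j| = |w_k|, \qquad g_i = \frac{\partial L}{\partial \tilde{w}_i}

\frac{\partial \tilde{w}_j}{\partial w_i}
  = \frac{\delta_{ij}}{m} - \frac{w_j}{m^2}\,\delta_{ik}\,\operatorname{sign}(w_k)

\frac{\partial L}{\partial w_i}
  = \sum_j g_j \frac{\partial \tilde{w}_j}{\partial w_i}
  = \frac{g_i}{m} \;-\; \delta_{ik}\,\frac{\operatorname{sign}(w_k)}{m^2} \sum_j g_j w_j
```

The first term is the "normal" path, which every weight receives. The second term is the path through max/abs; because of the delta it is nonzero only at the argmax position k, and it sums contributions from all weights in the layer.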

This reminds me of Normalized Direction-preserving Adam. Regarding your approach, you should probably use vector maximums instead of a single scalar divisor.
Also, nn.utils.weight_norm may be relevant.
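In case it helps, here is a minimal sketch of what nn.utils.weight_norm does to a layer. The layer sizes are arbitrary, chosen only for illustration. Note that it normalizes by an L2 norm per output row and learns a separate magnitude, so it is related to, but not the same as, dividing by the max weight. Recent PyTorch releases mark this function as deprecated in favor of a parametrization-based version, so treat this as illustrative only.

```python
import torch
import torch.nn as nn

# Arbitrary layer sizes, for illustration only.
layer = nn.utils.weight_norm(nn.Linear(20, 40), name="weight", dim=0)

# The single "weight" parameter is replaced by a magnitude (weight_g)
# and a direction (weight_v); "weight" is recomputed from them.
param_names = sorted(name for name, _ in layer.named_parameters())
g_shape = tuple(layer.weight_g.shape)  # one magnitude per output row
v_shape = tuple(layer.weight_v.shape)  # same shape as the original weight

out = layer(torch.randn(2, 20))
print(param_names, g_shape, v_shape, tuple(out.shape))
```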

Thank you! What do you mean by “vector maximums”?

Looking at the computational graph, it seems like the largest weight in each layer might be receiving two values of gradients, one normal, and one coming through the path with max and abs ops. Any idea how to verify that?
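One way to check this might be to compute the gradient twice, once with the divisor inside the graph and once with it detached, and look at where the two differ. The sketch below uses a small random matrix and a made-up loss as stand-ins for a real layer and training loss; the names loss_fn, grad_full and grad_direct are mine.

```python
import torch

torch.manual_seed(0)

# Stand-in "layer weights" and inputs; float64 keeps the comparison clean.
w = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
x = torch.randn(5, 3, dtype=torch.float64)

def loss_fn(weight):
    # Arbitrary loss, only used to produce gradients.
    return (x @ weight.t()).pow(2).sum()

# (a) Divisor is part of the graph, as in the forward() from the question.
scale = w.abs().max()
loss_fn(w / scale).backward()
grad_full = w.grad.clone()

# (b) Divisor treated as a constant: only the "normal" path remains.
w.grad = None
loss_fn(w / scale.detach()).backward()
grad_direct = w.grad.clone()

# Flattened positions where the two gradients disagree.
differs = ~torch.isclose(grad_full, grad_direct)
extra_positions = differs.flatten().nonzero().flatten().tolist()
argmax_index = w.detach().abs().argmax().item()

print("argmax position:", argmax_index)
print("positions with an extra gradient:", extra_positions)
```

If the guess is right, the only position with an extra gradient is the argmax position.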

torch.max(weights, dim=1,keepdim=True)[0]

Sorry, can you please explain why this is better?

Can you explain why a matrix max norm is better than vector (L2?) norms? My understanding is that using vector norms (either row-wise or column-wise) for normalization doesn’t “tie” features together, so outputs are more expressive/diversified, with balanced variance.

Are you talking about channel range equalization, i.e. dividing by the channel max instead of the layer max? That’s an interesting idea, but do we actually want to “untie” the features captured by different channels? If a channel is “dead”, with weights very close to zero, do we want to rescale it to be as strong as other, useful channels? Simply scaling down all channels by the same value preserves the relative channel importance information. I wonder if batch normalization does some form of channel equalization, but I believe it still allows some channels to go “dead”.

Anyway, I just tried it:

F.conv2d(inputs, self.weight / torch.abs(self.weight).max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0])
F.linear(inputs, self.weight / torch.abs(self.weight).max(dim=1, keepdim=True)[0])
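For anyone comparing the two variants, here is a small sketch of the shapes involved. The weight tensor is random and its size is arbitrary. It also shows that the chained max calls can be written as a single amax over several dimensions, which I find easier to read.

```python
import torch

torch.manual_seed(0)
weight = torch.randn(8, 3, 5, 5)  # (out_channels, in_channels, kH, kW)

# Chained form, as in the conv2d line above.
chained = (
    weight.abs()
    .max(dim=1, keepdim=True)[0]
    .max(dim=2, keepdim=True)[0]
    .max(dim=3, keepdim=True)[0]
)

# Equivalent single call: one maximum per output channel.
per_channel_max = weight.abs().amax(dim=(1, 2, 3), keepdim=True)

# Layer-wise divisor from the original question: a single scalar.
layer_max = weight.abs().max()

per_channel_scaled = weight / per_channel_max  # broadcasts over (8, 1, 1, 1)
layer_scaled = weight / layer_max

print(tuple(per_channel_max.shape), tuple(layer_max.shape))
```

With per-channel scaling every output channel ends up with a largest absolute weight of 1; with layer-wise scaling only the channel holding the layer maximum does.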

From what I see so far, after adjusting the learning rate and weight decay, it does not train as well as dividing by the layer max weight (~3% drop), while I was able to reach almost the baseline accuracy with layer max scaling (~0.5% drop). Training with per-channel scaling is not as stable as with layer-wise scaling, and is more sensitive to hyperparameter choices. It’s possible that with more effort I will close the gap between these two scaling methods. I still don’t quite understand what happens to the gradients, though; the computational graph does not make much sense to me.

I used pytorchviz for the diagram. (https://github.com/szagoruyko/pytorchviz)

All weight normalization variants restrict magnitudes, so there should be some layer that can rescale features freely.

It is an intermediate representation, so there is no value in “dead” channels. On the contrary, you get problems with vanishing gradients in multi-layer pipelines. Indeed, batch normalization helps to avoid “dead” channels (the affine transformation it applies is kinda auxiliary, I think).

You have a special cell in the weight matrix (located at argmax()) that all channels use to try to adjust their outputs. The row that contains this cell thus produces a feature with a different gradient magnitude and/or variance. So SGD suffers, while Adam adapts its per-parameter step sizes and mitigates this.
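One experiment that might isolate this effect (my own suggestion, not something tested in this thread) is to detach the divisor, so the argmax cell receives only its "normal" gradient. DetachedScaleLinear is a hypothetical name and the sizes are arbitrary. Be aware that with a detached divisor the raw weights are no longer protected from growing or shrinking by the gradient, so this may behave differently over a long training run.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DetachedScaleLinear(nn.Linear):
    """Hypothetical variant: the divisor is treated as a constant."""

    def forward(self, inputs):
        # detach() cuts the max/abs path out of the autograd graph.
        scale = self.weight.detach().abs().max()
        return F.linear(inputs, self.weight / scale, self.bias)


torch.manual_seed(0)
layer = DetachedScaleLinear(6, 4, bias=False)
out = layer(torch.randn(2, 6))
out.sum().backward()

# Every weight now gets grad_output / scale and nothing else.
grad_shape = tuple(layer.weight.grad.shape)
print(tuple(out.shape), grad_shape)
```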

The solution to your original question is very straightforward. The key idea is to use torch.nn.Parameter as shown below.

class MyCNN(torch.nn.Module):
    def __init__(self):
        super(MyCNN, self).__init__()
        
        conv1 = nn.Conv2d(..specify kernel size, stride, etc..)
        self.conv1_w = torch.nn.Parameter(conv1.weight)
        self.conv1_b = torch.nn.Parameter(conv1.bias)
        
        # have several more like this.
        
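    def forward(self, x):
        # Rescale the parameters in place at the start of the forward pass.
        # Assigning a plain tensor to an attribute registered as nn.Parameter
        # raises a TypeError, so divide in place under no_grad instead.
        # Autograd does not see this step; abs() keeps the range at [-1, 1].
        with torch.no_grad():
            self.conv1_w.div_(self.conv1_w.abs().max())
            self.conv1_b.div_(self.conv1_b.abs().max())

        x = torch.nn.functional.conv2d(x, self.conv1_w, self.conv1_b, stride=1, padding=0, dilation=1, groups=1)

        return x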
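
Another option that may be worth trying is torch.nn.utils.parametrize, which lets you keep a regular nn.Conv2d and attach the scaling to its weight. This is only a sketch: MaxAbsScale is a name I made up, and the layer sizes are borrowed from the question. Unlike rescaling the stored parameter, the division here stays inside the autograd graph, so the gradients should behave like those of the forward() in the original question.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class MaxAbsScale(nn.Module):
    """Hypothetical parametrization: divide by the largest absolute weight."""

    def forward(self, weight):
        return weight / weight.abs().max()


torch.manual_seed(0)
conv = nn.Conv2d(3, 32, kernel_size=5, bias=False)
parametrize.register_parametrization(conv, "weight", MaxAbsScale())

# conv.weight is now computed from the stored, unscaled parameter
# every time it is accessed, including inside conv.forward().
out = conv(torch.randn(2, 3, 32, 32))
max_abs = conv.weight.abs().max().item()
param_names = [name for name, _ in conv.named_parameters()]

print(tuple(out.shape), max_abs, param_names)
```

The optimizer then updates the unscaled tensor stored under parametrizations.weight.original, while the layer always convolves with the scaled version.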