Positive Weights

Is there a way to ensure that the weights of the network stay in a positive range throughout training?


Did you find a solution to this question?

After you’ve updated the weights, add the following lines to your code:

for p in mdl.parameters():
    p.data.clamp_(0)

Example:


import torch
import torch.nn as nn

# Toy data. Note the true slope is negative, so clamping the weight to be
# non-negative will keep the fit from being exact.
x = torch.arange(10, dtype=torch.float32).view(-1, 1)
y = -3 * x

class NN_Linear_Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.lm = nn.Linear(1, 1)

    def forward(self, X):
        return self.lm(X)

mdl = NN_Linear_Model()

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(mdl.parameters(), lr=0.0001)

for i in range(100):
    optimizer.zero_grad()
    y_pred = mdl(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimizer.step()
    # Project the parameters back onto the non-negative range after each update.
    for p in mdl.parameters():
        p.data.clamp_(0)

    if i % 10 == 0:
        print(f'loss is {loss} at iter {i}, weight: {list(mdl.parameters())[0].item()}')

list(mdl.parameters())

One caveat is that your model may not converge to the optimal point, since you are restricting where your parameters can go.


Does the loss function need to be changed to a specific one? When I applied the proposed solution, I got stuck at an accuracy of 10%.
I am using CrossEntropy as the loss function.

Yes, softmax won’t converge when negative inputs are not expressible but the gradients keep pushing the parameters toward them (such use of clamp_ is invisible to autograd). Maybe softmax(x - 10) would work, but it is still a hack. Not touching the parameters of some final layers may work better.
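For example, a minimal sketch of that last suggestion (the layer names fc1 and out are made up for illustration): clamp everything except the final layer after each optimizer step.

import torch
import torch.nn as nn

class SmallClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)  # constrained to be non-negative
        self.out = nn.Linear(128, 10)   # final layer left unconstrained

    def forward(self, x):
        return self.out(torch.relu(self.fc1(x)))

mdl = SmallClassifier()

# ...after optimizer.step() in the training loop:
for name, p in mdl.named_parameters():
    if not name.startswith('out'):  # skip the final layer's weight and bias
        p.data.clamp_(0)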

PS: actually, softmax/CrossEntropy mainly exist to force positivity. If the network outputs are always positive anyway, something like -torch.distributions.Categorical(probs=output).log_prob(target) may work as a loss.
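A minimal sketch of that loss, with made-up positive outputs and integer class targets (Categorical normalizes probs along the last dimension, so unnormalized positive outputs are fine):

import torch
from torch.distributions import Categorical

# Stand-ins for strictly positive network outputs (batch of 4, 10 classes)
# and integer class targets.
output = torch.rand(4, 10) + 0.1
target = torch.randint(0, 10, (4,))

# Categorical normalizes `probs`, so positive unnormalized outputs are enough.
loss = -Categorical(probs=output).log_prob(target).mean()
print(loss)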

Sorry I did not return here earlier… I already solved it, and it converged with softmax with no problem (just lower accuracy: it dropped from 95% to 82%).

The trick is to parameterize the weights by their logarithms. The log weights are allowed to vary freely over the real numbers, and an exponential map converts them to strictly positive weights before the weight is applied to the input data.

Example code:

import torch
import torch.nn as nn

class PositiveLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super(PositiveLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # The weight is stored in log space, so it can take any real value.
        self.log_weight = nn.Parameter(torch.empty(out_features, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.log_weight)

    def forward(self, input):
        # exp() maps the free log weights to strictly positive weights.
        return nn.functional.linear(input, self.log_weight.exp())
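A quick sanity check of this layer, reusing the imports and the PositiveLinear class from the snippet above (the target y = 3x is made up so that a positive weight can actually fit it):

x = torch.arange(10, dtype=torch.float32).view(-1, 1)
y = 3 * x

layer = PositiveLinear(1, 1)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

for i in range(500):
    optimizer.zero_grad()
    loss = loss_fn(layer(x), y)
    loss.backward()
    optimizer.step()

# The effective weight exp(log_weight) is positive by construction.
print(layer.log_weight.exp())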

Is there any specific reason for using log? For example, why can’t we use the square of the weight?

You can use an activation function which is constrained to be positive, e.g., Softplus: Softplus — PyTorch 2.1 documentation
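For example, a minimal sketch of a softplus-parameterized linear layer, analogous to the PositiveLinear above (the name SoftplusLinear and the initialization are just illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftplusLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Unconstrained underlying parameter; softplus makes the effective
        # weight strictly positive in the forward pass.
        self.raw_weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.raw_weight)

    def forward(self, input):
        return F.linear(input, F.softplus(self.raw_weight))

layer = SoftplusLinear(4, 2)
print(torch.all(F.softplus(layer.raw_weight) > 0))  # effective weights are all positive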

I agree that you should use softplus, but the other answers are outdated because they don’t account for the new awesome PyTorch parametrization module! Take a look at this tutorial, then reparametrize your layers’ weights with softplus.

Also here is some example code that does what I describe:

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class SoftplusParameterization(nn.Module):
    def forward(self, X):
        return nn.functional.softplus(X)

# Example registration of this parameterization transform
example_layer = nn.Conv2d(3,3,3)
parametrize.register_parametrization(example_layer, "weight", SoftplusParameterization())

assert torch.all(example_layer.weight>0) # now all > 0
print(example_layer.weight)
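
Just to add: the optimizer updates the unconstrained underlying parameter, so the exposed weight stays positive after every step. Continuing the snippet above (the input/target shapes are made up for the example):

optimizer = torch.optim.SGD(example_layer.parameters(), lr=0.01)

inp = torch.randn(1, 3, 8, 8)
target = torch.randn(1, 3, 6, 6)  # Conv2d(3, 3, 3) output for an 8x8 input

loss = nn.functional.mse_loss(example_layer(inp), target)
loss.backward()
optimizer.step()

assert torch.all(example_layer.weight > 0)  # still positive after the update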