How to make the parameter of torch.nn.Threshold learnable?

I have just tried to do something similar: clip activations at some upper bound, and make that bound a learnable parameter. However, I’m not sure this is working correctly, because the network always tries to increase the upper bound beyond the optimal value I found by trial and error:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.act_max = nn.Parameter(torch.Tensor([0]), requires_grad=True)

        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(64 * 5 * 5, 10)

    def forward(self, input):
        conv1 = self.conv1(input)
        pool1 = self.pool(conv1)
        relu1 = self.relu(pool1)

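        # clamp activations above act_max (in-place masked assignment)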
        relu1[relu1 > self.act_max] = self.act_max

        conv2 = self.conv2(relu1)
        pool2 = self.pool(conv2)
        relu2 = self.relu(pool2)
        relu2 = relu2.view(relu2.size(0), -1)
        return self.linear(relu2)


model = Net()
# initialise weights: kaiming for conv/linear, constant 1.0 for act_max
for p in model.parameters():
    if p.dim() > 1:
        nn.init.kaiming_normal_(p)
nn.init.constant_(model.act_max, 1.0)
model = model.cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
for epoch in range(100):
    for i in range(1000):
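        # input and label assumed to come from a data loader (omitted in this snippet)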
        output = model(input)
        loss = nn.CrossEntropyLoss()(output, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
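        # additional manual SGD step on act_max (besides optimizer.step())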
        model.act_max.data = model.act_max.data - 0.001 * model.act_max.grad.data

Also, I don’t understand why the value doesn’t update if I remove the last line (even though the gradients are being calculated, and change every iteration).

First, here are the functions for learning maximums and minimums:

def ClampMin(x, val):
    """
    Clamps x to minimum value 'val' (val < 0.0).
    Built from out-of-place ops only, so gradients also flow into 'val'.
    """
    return x.clamp(max=0.0).sub(val).clamp(min=0.0).add(val) + x.clamp(min=0.0)

def ClampMax(x, val):
    """
    Clamps x to maximum value 'val' (val > 0.0).
    Built from out-of-place ops only, so gradients also flow into 'val'.
    """
    return x.clamp(min=0.0).add(val).clamp(max=0.0).sub(val) + x.clamp(max=0.0)

To demonstrate, I rewrote your code slightly:

class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.act_max = nn.Parameter(torch.Tensor([0.5]))

        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(64 * 5 * 5, 10)

    def forward(self, input):
        conv1 = self.conv1(input)
        pool1 = self.pool(conv1)
        relu1 = self.relu(pool1)

        relu1 = ClampMax(relu1, self.act_max)

        conv2 = self.conv2(relu1)
        pool2 = self.pool(conv2)
        relu2 = self.relu(pool2)
        relu2 = relu2.view(relu2.size(0), -1)
        return self.linear(relu2)


model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

for epoch in range(1000):
    for _ in range(10):
        optimizer.zero_grad()
        # ------------------------------
        # some fake data and label
        inputs = torch.randn(1,3,32,32) * 0.1
        label = torch.randint(10, [1])
        # ------------------------------
        output = model(inputs)
        loss = nn.CrossEntropyLoss()(output, label)
        loss.backward()
        optimizer.step()
    #------------------------------
    print(epoch, loss.item(), model.act_max.data.item())

Your original idea (writing the maximum activation value into the output tensor of the MaxPool2d/ReLU stage in place) won’t work reliably, since in-place updates to tensors are not fully supported by Autograd and can overwrite values that are needed for the backward pass (see the docs on in-place operations with autograd).
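
As a self-contained illustration of the kind of failure I mean (a minimal sketch, not your exact code):

import torch

a = torch.randn(3, requires_grad=True)
b = a.exp()          # exp saves its output for the backward pass
b[b > 1.0] = 1.0     # in-place masked write bumps b's version counter
b.sum().backward()   # RuntimeError: a variable needed for gradient
                     # computation has been modified by an inplace operation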

Running this code should show that the act_max Parameter is being updated after each mini-batch.

I think there’s a typo somewhere in your functions, because, for example, ClampMax(torch.tensor(10.0), 5.0) returns -5 instead of 5.

There’s a bug for sure - sorry about that…

The corrected code:

def ClampMax(x, val):
    """
    Clamps x to maximum value 'val' (val >= 0.0).
    """
    return x.clamp(min=0.0).sub(val).clamp(max=0.0).add(val) + x.clamp(max=0.0)

A simple check:

x = torch.tensor([ 0.3889, -0.0575,  0.5435, -0.6713,  1.4574])
print(f"x :{x}")
print(ClampMin(x, torch.tensor([-0.5])))
print(ClampMax(x, torch.tensor([ 0.5])))

results in:

x :tensor([ 0.3889, -0.0575, 0.5435, -0.6713, 1.4574])
tensor([ 0.3889, -0.0575, 0.5435, -0.5000, 1.4574])
tensor([ 0.3889, -0.0575, 0.5000, -0.6713, 0.5000])
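
To see that gradients also flow into val, one more quick sketch (reusing the x above): summing the clamped output and backpropagating should give val a gradient equal to the number of elements that hit the max:

val = torch.tensor([0.5], requires_grad=True)
y = ClampMax(x, val)
y.sum().backward()
print(val.grad)   # tensor([2.]): 0.5435 and 1.4574 were clamped at val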

I hope it helps!

Thanks! I just tried 3 different ways to do this:

relu1[relu1 > self.act_max] = self.act_max
relu1 = torch.where(relu1 > self.act_max, self.act_max, relu1)
relu1 = torch.clamp(relu1,min=0.0).sub(self.act_max).clamp(max=0.0).add(self.act_max).add(relu1.clamp(max=0.0))

and all three behave exactly the same. My original issue of act_max not updating was caused by something else (bad per-parameter optimizer settings). So it turns out it doesn’t matter whether the clamp is in-place or not: the gradients are exactly the same (though the in-place version is slower).
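
For anyone hitting the same issue, a quick (hypothetical) sanity check is to confirm that act_max actually sits in an optimizer param group with a nonzero learning rate:

for group in optimizer.param_groups:
    for p in group["params"]:
        if p is model.act_max:
            print("act_max found, lr =", group["lr"])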

In my network act_max gets to about 0.9 and stays there. However, I’m not sure these gradients are being computed correctly. I’m guessing they are somehow based on the gradients of the activations, but I don’t quite understand how.

My initial guess was that if the average of all activations is increasing, the upper bound should increase as well, but I’m not sure what the relationship should be. In my network the activations keep growing after the threshold stabilizes. I wonder if we should look only at the activations which are currently above the upper bound and check their average gradient: if those activations want to grow, then it might make sense to increase the threshold. However, it’s not clear whether this would actually decrease the loss. Also, perhaps the thresholding serves as a regularizer, so we might want to use it even if it increases the training loss slightly.
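
As a small made-up example of what Autograd computes here: the gradient of the threshold is just the sum of the upstream gradients of the elements that were clamped:

x = torch.tensor([0.2, 0.8, 1.5, -0.3], requires_grad=True)
t = torch.tensor([0.5], requires_grad=True)

y = torch.where(x > t, t, x)                    # same clamp as the second variant above
y.backward(torch.tensor([1.0, 2.0, 3.0, 4.0]))  # pretend upstream gradients

print(t.grad)   # tensor([5.]): 2 + 3 from the two clamped elements
print(x.grad)   # tensor([1., 0., 0., 4.]): clamped elements get no gradient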

Do you have any insights from your experiments? Have you used a different learning rate for the threshold?

I tested the in-place update and it seems my assumption about it was indeed wrong.

After a few SGD iterations the gradient of act_max becomes very close to zero, but it looks like act_max can also stabilise at a negative value, causing all the relu1 outputs to become a constant negative value; perhaps an interesting side effect that could be investigated further.

I don’t have any other insights since I must admit that I’ve abandoned that approach. I wasn’t getting the results I had hoped for and I now consider it to be one of my (many) failed experiments.

@hughperkins
Can someone please explain mathematically how the gradient for the threshold is computed? 🙂
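
(For what it’s worth, the math works out like this: writing the clamp as y_i = min(x_i, t) for the non-negative activations, the local derivative is ∂y_i/∂t = 1 if x_i > t and 0 otherwise, so the chain rule gives ∂L/∂t = Σ_{i : x_i > t} ∂L/∂y_i, i.e. the sum of the upstream gradients of exactly those activations that are currently being clamped. This matches the numeric check above.)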