Adaptive Computation Time help

Hi everyone,

I am working on implementing a system that uses Adaptive Computation Time (ACT) as outlined here. I am building a basic example where the input is a random integer from 1 to 9 and the output should always be 10. Instead of choosing the output directly, the network runs in a loop, and each loop adds 1 to the input. The network has a halting neuron that decides when to stop looping. The hope is that the network learns to stop when the running value reaches 10, at which point it returns that value.
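A minimal framework-free sketch of the toy task as described, assuming inputs are drawn uniformly from the integers 1 to 9. The helpers `sample_example` and `ideal_steps` are hypothetical names. `ideal_steps` counts how many +1 loops a perfect halting neuron would need before the value reaches 10 (always 10 - x), which gives a target to compare the model's reported step count against.

```python
import random
from typing import Tuple

TARGET = 10  # the output should always be 10


def sample_example(rng: random.Random) -> Tuple[int, int]:
    """Draw one (input, target) pair with the input uniform over 1..9."""
    return rng.randint(1, 9), TARGET


def ideal_steps(x: int) -> int:
    """Count the +1 loops needed before the running value reaches TARGET."""
    value, steps = x, 0
    while value < TARGET:
        value += 1  # one loop of the network adds 1
        steps += 1
    return steps


rng = random.Random(0)  # fixed seed so the samples are reproducible
for _ in range(3):
    x, y = sample_example(rng)
    print(f"input={x} target={y} ideal loops={ideal_steps(x)}")
```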

The problem I am facing is that the network always runs the maximum number of loops, regardless of the input. Since the output is a weighted sum of the per-step values, with weights given by the halting neuron's output at each step, the network is able to “cheat”: it uses a precisely tuned halting output at the last step to land roughly close to 10 and decrease the loss. To stop the network from always using the maximum number of loops, I tried to implement a “ponder loss” that penalizes the network for each loop it uses. I am certain that I am doing this wrong: I set the weight on the ponder loss very high, expecting the network to do only one loop, but it still does the maximum number of loops.
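For comparison, the halting rule from Graves' Adaptive Computation Time paper can be sketched without any framework. Under those definitions the loop stops at the first step N where the cumulative halting probability reaches 1 - ε (capped here at the step limit, playing the role of `maxSteps`), the final step is weighted by the remainder R = 1 - (h_1 + ... + h_{N-1}) so that the weights sum to 1, and the ponder cost is N + R. The function name `act_halting`, its list-based interface, and the default ε = 0.01 are illustrative assumptions, not the code above or a library API; the example values reuse the toy task with input 7.

```python
from typing import List, Tuple


def act_halting(
    halting: List[float], values: List[float], epsilon: float = 0.01
) -> Tuple[int, float, float, float]:
    """Return (N, R, weighted output, ponder cost N + R) for one sequence.

    halting[n - 1] is the halting probability h_n produced at step n and
    values[n - 1] is the value produced at that step; both lists have one
    entry per allowed step (len == max_steps).
    """
    if not halting or len(halting) != len(values):
        raise ValueError("need one halting probability per step value")
    max_steps = len(halting)

    cumulative = 0.0  # sum of h_1 .. h_{n-1}, i.e. excluding the current step
    n_halt = max_steps
    for n in range(1, max_steps + 1):
        # Halt at the first step whose cumulative sum reaches 1 - epsilon,
        # or at the step limit if that never happens.
        if n == max_steps or cumulative + halting[n - 1] >= 1 - epsilon:
            n_halt = n
            break
        cumulative += halting[n - 1]

    # The last step gets the remainder, so the weights sum to 1.
    remainder = 1.0 - cumulative
    weights = halting[: n_halt - 1] + [remainder]
    output = sum(w * v for w, v in zip(weights, values))

    # Ponder cost N + R; only R depends smoothly on the halting outputs.
    ponder_cost = n_halt + remainder
    return n_halt, remainder, output, ponder_cost


x = 7
values = [float(x + n) for n in range(1, 21)]  # value after each "+1" loop

# Ideal behaviour: halt with h = 1 on the step where the value first equals 10.
print(act_halting([0.0, 0.0, 1.0] + [0.0] * 17, values))  # → (3, 1.0, 10.0, 4.0)

# "Cheating" behaviour: tiny halting outputs, so the loop runs to the limit.
print(act_halting([0.01] * 20, values))
```

Because N only changes in whole steps, the gradient of N + R with respect to the halting outputs comes through R alone, and minimizing it pushes the earlier halting outputs up.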

Here is the full implementation of the model:

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DynamicNet(nn.Module):
    def __init__(self, hiddenSize, epsilon, maxSteps=20):
        super().__init__()
        # Halting unit: two stacked linear layers (no activation between them) -> one logit
        self.haltingNeuron = nn.Sequential(nn.Linear(1, hiddenSize), nn.Linear(hiddenSize, 1))
        self.epsilon = epsilon
        self.maxSteps = maxSteps

    def forward(self, x):
        # Accumulators are updated out-of-place: writing into a leaf tensor
        # that has requires_grad=True raises a RuntimeError in PyTorch.
        haltTotal = torch.zeros(1, device=device)       # running sum of halting outputs
        computePenalty = torch.zeros(1, device=device)  # ponder penalty: sum of (1 - h)
        currentOutput = x
        finalOutput = torch.zeros(self.maxSteps, device=device)  # per-step weighted values (not returned)

        for step in range(self.maxSteps):
            currentOutput = currentOutput + 1  # each loop adds 1 to the value

            haltingOutput = torch.sigmoid(self.haltingNeuron(currentOutput))
            computePenalty = computePenalty + (1 - haltingOutput)

            if haltTotal[0] > 1 - self.epsilon:
                finalOutput[step] = (currentOutput * ((1 - self.epsilon) - haltTotal[0]))  # Finish final total with the remainder of halt total
                return (currentOutput * ((1 - self.epsilon) - haltTotal[0]), computePenalty[0] * 5, step)

            # Previously placed after the return above, where it could never run
            finalOutput[step] = (currentOutput * haltingOutput[0])
            haltTotal = haltTotal + haltingOutput

        return (currentOutput * ((1 - self.epsilon) - haltTotal[0]), computePenalty[0] * 5, self.maxSteps)
```

Please let me know if I need to provide any additional information! Thanks in advance!