Output always the same

Hey.
So I made a neural net that is supposed to play a board game, and for every game state it outputs the probability of winning at that point.
When I’m training the net, I have it occasionally print its output, target, loss, and accuracy. Here, the output varies.
However, when I just have the net predict a single outcome or have it play the game, it always produces the exact same output.
I don’t understand why: it’s capable of producing different outputs for different inputs from the training set, yet even when I have it produce an output from one training example (instead of a batch), it reverts to that same output.
Why is this happening?
If you need me to explain my net more in-depth or post some code, let me know.
Any help is appreciated, thanks :pray:

EDIT
Here’s the definition of my network:

import torch
import torch.nn as nn
import torch.optim as optim

class DiceProbabilisticNN(nn.Module):

    boardState = [[0 for i in range(13)] for j in range(11271)]
    dicePairings = [[0 for i in range(5)] for j in range(8301)]
    fourDice = [[0 for i in range(5)] for j in range(3757)]

    def __init__(self, inputSize, outputSize):
        super(DiceProbabilisticNN, self).__init__()
        self.numTurns = self.readFile()

        self.lin1 = nn.Linear(inputSize, 244)
        self.leak1 = nn.LeakyReLU()
        self.relu1 = nn.ReLU()
        self.batch1 = nn.BatchNorm1d(244)
        self.lin2 = nn.Linear(244, outputSize)
        self.relu2 = nn.ReLU()
        self.batch2 = nn.BatchNorm1d(outputSize)

    def forward(self, input):
        output = self.lin1(input)
        output = self.leak1(output)
        output = self.batch1(output)
        output = self.lin2(output)
        output = self.relu2(output)
        output = self.batch2(output)
        return output

The input format is a 1x122 vector of bits; it describes three game states (current position, position at start of turn, opponent’s position) and two pairs of dice.
The actual input is a batch of these vectors over five random turns (we have a data file with roughly 600 recorded turns).
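For reference, here is roughly how one of these batches is shaped (a simplified sketch; the random bits below just stand in for my real encoding methods, which are omitted):

import torch
from torch.autograd import Variable

# Hypothetical stand-in for the real encoding: one turn -> a 1x122 bit vector.
def randomTurnVector():
    return torch.FloatTensor(1, 122).random_(0, 2)

# Five random turns stacked into a 5x122 training batch.
diceInput = Variable(torch.cat([randomTurnVector() for _ in range(5)], 0))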
Here’s the training and whatnot:

diceNN = DiceProbabilisticNN(122, 1)
diceOptimizer = optim.SGD(diceNN.parameters(), lr=0.001)
diceLossFunc = nn.MSELoss()

def train(input, target):
    diceOptimizer.zero_grad()
    output = diceNN(input)
    loss = diceLossFunc(output, target)
    loss.backward()
    diceOptimizer.step()
    return output, loss

allLosses = []
for i in range(10000):
    diceInput, diceTarget = diceNN.getRandomTrainingExample()
    diceOutput, diceLoss = train(diceInput, diceTarget)
    diceAccuracy = (1 - ((diceOutput - diceTarget) ** 2)) * 100

    if i % 1000 == 0:
        allLosses.append(diceLoss.data[0])
        print("Progress: %s percent" % int(i / 100))
    if i in (0, 4999, 9999):
        print("Output: %s, Target: %s, Loss: %s, Accuracy: %s" % (diceOutput.data[0][0], float(diceTarget.data[0]), float(diceLoss.data[0]), diceAccuracy.data[0][0]))
    torch.save(diceNN.state_dict(), "./DiceNet.pth")

It’s a bit sloppy right now, I know, and I have omitted a bunch of the complicated input-constructing and file-reading methods, but I am 100% certain that those work properly.
When I run the training cycle, I get this:

Progress: 0 percent
Output: -0.7317960858345032, Target: 0.2633669972419739, Loss: 0.7930728197097778, Accuracy: 0.965040922164917
Progress: 10 percent
Progress: 20 percent
Progress: 30 percent
Progress: 40 percent
Output: 0.21730580925941467, Target: 0.3980500102043152, Loss: 0.015895208343863487, Accuracy: 96.733154296875
Progress: 50 percent
Progress: 60 percent
Progress: 70 percent
Progress: 80 percent
Progress: 90 percent
Output: 0.7285764217376709, Target: 0.7772169709205627, Loss: 0.01874120719730854, Accuracy: 99.76341247558594

As you can see, the outputs are varied, and even look like what they should be. However, when I run two random examples (batch size 1) through the network, the output is exactly the same, namely 0.5053 (the format here is input vector, then output probability):

(Variable containing:
    0  1  1  1  0  1  …  0  0  1  0  1
[torch.FloatTensor of size 1x122]
, Variable containing:
 0.5053
[torch.FloatTensor of size 1x1]
)
(Variable containing:
    0  1  1  1  0  1  …  0  1  0  0  1
[torch.FloatTensor of size 1x122]
, Variable containing:
 0.5053
[torch.FloatTensor of size 1x1]
)

(The full 1x122 dumps are abbreviated here; the two input vectors differ at several positions, yet both outputs are 0.5053.)
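For clarity, the check that produced the output above is essentially this (simplified; encodeGameState and game are hypothetical stand-ins for my real input-building code):

from torch.autograd import Variable

state = encodeGameState(game)    # hypothetical: builds one 1x122 bit vector
prob = diceNN(Variable(state))   # batch of size 1; always comes out as 0.5053
print(state, prob)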

Is batch size simply the issue? This problem occurs both when I pass single training examples through the net and when I pass single organic examples through it.

Could you be mistakenly passing it the same input every time? Without seeing the code it is difficult to help you more :confused:

That would make me look really stupid lmao, but no, my inputs are different… check the edit for some of my code (please).

I may have found your problem. It has to do with batch normalization. When evaluating your model, you have to tell it to disable batch norm’s training behavior by calling diceNN.train(False) (equivalently, diceNN.eval()). Otherwise, applying batch normalization to a single sample normalizes it to a vector of zeros, because the sample is exactly equal to the batch mean.
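Something like this, assuming example is any single-sample input:

diceNN.train(False)      # same as diceNN.eval(): BatchNorm switches to its running statistics
prob = diceNN(example)   # single-sample outputs now vary with the input
diceNN.train(True)       # switch back before any further training

In training mode, BatchNorm1d normalizes each batch with that batch’s own mean and variance, so a batch of one sample is normalized to zeros regardless of the input; in eval mode it uses the running statistics accumulated during training.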
