I tried to write my own simple classification model but encounter some very weird problem. as you can see below I tested the model on randomly generated values and expected to get different prediction labels from the softmax layer. However, I keep getting pretty much the same label on different instances. Moreover, the softmax vectors are quite similar. see below output examples
Can you please advise me and point the problem?
data = np.random.random((1000, 100))
labels = np.random.randint(13, size=1000)
labels = Variable(torch.from_numpy(labels).long())
class MyModel(nn.Module):
def __init__(self):
super(ZigZagModel, self).__init__()
self.input_layer = nn.Linear(100, 64)
self.middle_layer = nn.Linear(64, 32)
self.output_layer = nn.Linear(32, 13)
def forward(self, x):
x = F.relu(self.input_layer(x))
x = F.relu(self.middle_layer(x))
x = F.relu(self.output_layer(x))
return x
softmax_op = nn.Softmax()
loss_op = nn.CrossEntropyLoss()
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
train_size = data.shape[0]
for i in range(0, 200000):
optimizer.zero_grad()
input_x = Variable(torch.Tensor(data[i % train_size, :]))
x = model(input_x)
x = x.unsqueeze(0)
softmax = softmax_op(x)
loss = loss_op(x, labels[i % train_size])
loss.backward()
optimizer.step()
# print the softmax
print(softmax.data.numpy())
# print the pred vs true label
print(np.argmax(softmax.data.numpy()), "vs", labels.data.numpy()[i%train_size])
Output:
[[ 0.0864875 0.07388693 0.07388693 0.08302182 0.08123572 0.07388693
0.07388693 0.07388693 0.07388693 0.08427258 0.07388693 0.07388693
0.07388693]]
0 vs 2
[[ 0.08640617 0.07394336 0.07493041 0.08375949 0.08198079 0.07394336
0.07394336 0.07394336 0.07394336 0.08137626 0.07394336 0.07394336
0.07394336]]
0 vs 11
[[ 0.08721988 0.07372069 0.0762369 0.08548821 0.07881035 0.07372069
0.07372069 0.07372069 0.07372069 0.08247914 0.07372069 0.07372069
0.07372069]]
0 vs 3
[[ 0.08426634 0.07351052 0.07711877 0.08805882 0.0829467 0.07351052
0.07351052 0.07351052 0.07351052 0.07952524 0.07351052 0.07351052
0.07351052]]
3 vs 6
[[ 0.08783976 0.07422064 0.07422064 0.08070375 0.07955535 0.07422064
0.07422064 0.07422064 0.07422064 0.08391535 0.07422064 0.07422064
0.07422064]]
0 vs 11
[[ 0.09056024 0.07391413 0.07391413 0.08014151 0.07759341 0.07391413
0.07391413 0.07391413 0.07391413 0.08647773 0.07391413 0.07391413
0.07391413]]
0 vs 6
[[ 0.08510514 0.07396153 0.07396153 0.08281226 0.08141673 0.07396153
0.07396153 0.07396153 0.07396153 0.08501204 0.07396153 0.07396153
0.07396153]]
0 vs 2
[[ 0.08639476 0.07395539 0.07487675 0.08390874 0.08175325 0.07395539
0.07395539 0.07395539 0.07395539 0.08142342 0.07395539 0.07395539
0.07395539]]
0 vs 3
[[ 0.08491787 0.0737853 0.07441458 0.08397128 0.08479271 0.0737853
0.0737853 0.0737853 0.0737853 0.08162116 0.0737853 0.0737853
0.0737853 ]]
0 vs 5
[[ 0.08539549 0.07431633 0.07440206 0.0826967 0.08185819 0.07431633
0.07431633 0.07438069 0.07431633 0.08071043 0.07465844 0.07431633
0.07431633]]
0 vs 10
[[ 0.08677328 0.07366005 0.07612024 0.08513239 0.07851321 0.07366005
0.07366005 0.07530047 0.07366005 0.08241236 0.07378771 0.07366005
0.07366005]]
0 vs 5
[[ 0.08555327 0.07402726 0.0745035 0.08586194 0.07993676 0.07402726
0.07402726 0.07402726 0.07402726 0.08192644 0.07402726 0.07402726
0.07402726]]