I am trying to get a simple network to output the probability that a number falls into one of three classes: smaller than 1.1, between 1.1 and 1.5, and bigger than 1.5. I am using cross-entropy loss with class labels 0, 1, and 2, but cannot solve the problem.
Every time I train, the network outputs the maximum probability for class 2, regardless of the input. The lowest loss I seem to be able to achieve is around 0.9. Any advice on where I am going wrong would be greatly appreciated! All code is below.
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau


class gating_net(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(gating_net, self).__init__()
        self.linear1 = nn.Linear(input_dim, 32)
        self.linear2 = nn.Linear(32, output_dim)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.sigmoid(self.linear2(x))
        return x
learning_rate = 0.01
batch_size = 64
epochs = 500
test = 1

gating_network = gating_net(1, 3)
optimizer = torch.optim.SGD(gating_network.parameters(), lr=learning_rate, momentum=0.9)
# reduce the learning rate when the loss plateaus (mode='min', since the loss is minimized)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, verbose=True)
for epoch in range(epochs):
    input_ = []
    label_ = []
    for i in range(batch_size):
        # draw a random value from {1.0, 1.1, ..., 2.0}
        scale = random.randint(10, 20) / 10
        input = scale
        if scale < 1.1:
            label = np.array([0])
        elif scale <= 1.5:  # includes the boundary values 1.1 and 1.5 in class 1
            label = np.array([1])
        else:
            label = np.array([2])
        input_.append(np.array([input]))
        label_.append(label)

    optimizer.zero_grad()
    # get output from the model, given the inputs
    output = gating_network.forward(torch.FloatTensor(input_))
    old_label = torch.FloatTensor(label_)
    # get loss for the predicted output
    loss = nn.CrossEntropyLoss()(output, old_label.squeeze().long())
    # get gradients w.r.t. the parameters
    loss.backward()
    # update parameters
    optimizer.step()
    scheduler.step(loss)

    print('epoch {}, loss {}'.format(epoch, loss.item()))
    if loss.item() < 0.01:
        print("########## Solved! ##########")
        torch.save(gating_network.state_dict(), './supervised_learning/run_{}.pth'.format(test))
        break
    # save a checkpoint every 100 epochs
    if epoch % 100 == 0:
        torch.save(gating_network.state_dict(), './run_{}.pth'.format(test))
nn.CrossEntropyLoss expects the model output to contain raw logits, not probabilities, so you have to remove the last sigmoid.
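A minimal sketch of the corrected forward method (everything else in the class stays the same); the last linear layer's raw output is returned directly, since nn.CrossEntropyLoss applies log-softmax internally:

    def forward(self, x):
        x = F.relu(self.linear1(x))
        # return raw logits; nn.CrossEntropyLoss applies log-softmax internally
        return self.linear2(x)

If you want actual probabilities at inference time, apply F.softmax(output, dim=1) outside the loss computation.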
Also, unrelated to the issue, but you should call the model directly via model(x) instead of using the .forward method. Otherwise hooks could be ignored (not important for your current code, but worth keeping in mind for future use cases).
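In your training loop that would simply be:

    output = gating_network(torch.FloatTensor(input_))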
After removing the sigmoid, switching to optim.Adam, and increasing the number of epochs to more than 3000, your model seems to work fine.
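A sketch of the changed setup, assuming Adam's default learning rate (the exact value isn't critical here, so you may still want to tune it):

    epochs = 3000
    # Adam with its default learning rate of 1e-3 (an assumption; tune as needed)
    optimizer = torch.optim.Adam(gating_network.parameters())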