Train loss doesn't decrease for regression task

Hello,

I have a kind of time-series prediction task where the output shape is variable. Therefore I interpolated my dataset to stretch or compress the data to an equal length of 10. Note that the values in each target tensor sum up to 1 - they are relative values.
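
For context, the resampling step looks roughly like this (a minimal sketch, not my exact code; the function name is just illustrative):

import numpy as np

def resample_to_fixed_length(values, target_len=10):
    # Stretch or compress a variable-length sequence to target_len points
    # via linear interpolation, then renormalize so the result sums to 1.
    values = np.asarray(values, dtype=np.float32)
    old_x = np.linspace(0.0, 1.0, num=len(values))
    new_x = np.linspace(0.0, 1.0, num=target_len)
    resampled = np.interp(new_x, old_x, values)
    return resampled / resampled.sum()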

After running training for 50 epochs the epoch loss stays the same and does not decrease. I know there are several topics about this problem, but none of them relate to a ‘multi-regression output’ where the target values sum up to 1.

My custom dataset:

import numpy as np
import torch
from torch.utils import data


class BADataset(data.Dataset):
    def __init__(self, X, y):
        # X and y are expected to provide .to_numpy() (e.g. pandas objects)
        self.X = torch.tensor(X.to_numpy().astype(np.float32))
        self.y = torch.tensor(y.to_numpy().astype(np.float32))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        return self.X[index], self.y[index]

My network:

class Net(nn.Module):
    def __init__(self, inputsize, hiddensize, outputsize, dropout):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(inputsize, hiddensize)
        self.fc2 = nn.Linear(hiddensize, outputsize)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x

Training process:

    for epoch in range(EPOCHS):
        net.train()
        epoch_loss = list()
        cum_loss_train = list()
        running_loss = 0.0
        for idx, data in enumerate(train_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss.append(loss.item())
        cum_loss_train.append(np.mean(epoch_loss))
        print(f'{datetime.datetime.now()} | EPOCH: {epoch} | LOSS: {np.round(np.mean(epoch_loss), 5)}')

Parameters:

    BATCH_SIZE = 8
    LEARNING_RATE = 0.05
    MOMENTUM = 0.00
    EPOCHS = 50
    INPUT_SIZE = 40
    HIDDEN_SIZE = 90
    OUTPUT_SIZE = 10
    DROPOUT = 0.8
    EVAL_STEP = 1
    criterion = nn.MSELoss()
    optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE)

I have already played around with the parameters: changed the learning rate, hidden layer size, dropout value, batch size, etc. But nothing solved the problem.
I also tried the Adam optimizer and changed the loss function to L1Loss.

Also very strange is the model output after training for a few epochs:

tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<UnbindBackward0>)
tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<UnbindBackward0>)
tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<UnbindBackward0>)
tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<UnbindBackward0>)
tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<UnbindBackward0>)
tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<UnbindBackward0>)

The model always sets the second output to 1 and the rest to 0.
A reasonable output for the segments should look like this:

-0.00103,0.16046,0.20881,0.57528,0.02601,-0.00107,0.00423,0.02641,0.00045,0.00045
0.01369,0.22102,0.67908,0.04893,0.00032,0.00958,0.02267,0.00380,0.00045,0.00045
0.16022,0.20850,0.57443,0.02597,-0.00107,0.00423,0.02637,0.00045,0.00045,0.00045
0.22323,0.68586,0.04942,0.00032,0.00968,0.02290,0.00723,0.00045,0.00045,0.00045
0.24072,0.66320,0.02998,-0.00123,0.00488,0.05265,0.00825,0.00052,0.00052,0.00052
0.85768,0.06180,0.00040,0.01210,-0.00040,0.05767,0.00904,0.00057,0.00057,0.00057
0.86787,0.03924,-0.00161,0.01325,-0.00048,0.06890,0.01080,0.00068,0.00068,0.00068
0.40020,0.00262,0.08491,0.07184,-0.00261,0.37345,0.05853,0.00368,0.00368,0.00368

Kind of like a multi-label classification where the probability is distributed over 10 classes.

Finally, here is the training loss output:

2023-02-02 16:02:36.476878 | EPOCH: 0 | LOSS: 522.06665
2023-02-02 16:02:51.127700 | EPOCH: 1 | LOSS: 522.06669
2023-02-02 16:03:06.418139 | EPOCH: 2 | LOSS: 522.06669
2023-02-02 16:03:21.640622 | EPOCH: 3 | LOSS: 522.06669
2023-02-02 16:03:36.421470 | EPOCH: 4 | LOSS: 522.06669
2023-02-02 16:03:50.987930 | EPOCH: 5 | LOSS: 522.06669
2023-02-02 16:04:05.824788 | EPOCH: 6 | LOSS: 522.06669
2023-02-02 16:04:20.751109 | EPOCH: 7 | LOSS: 522.06669
2023-02-02 16:04:35.314357 | EPOCH: 8 | LOSS: 522.06669
2023-02-02 16:04:49.541290 | EPOCH: 9 | LOSS: 522.06669
2023-02-02 16:05:04.630747 | EPOCH: 10 | LOSS: 522.06669
2023-02-02 16:05:19.952308 | EPOCH: 11 | LOSS: 522.06669
2023-02-02 16:05:35.211317 | EPOCH: 12 | LOSS: 522.06669
2023-02-02 16:05:50.412194 | EPOCH: 13 | LOSS: 522.06669
2023-02-02 16:06:05.951134 | EPOCH: 14 | LOSS: 522.06669
2023-02-02 16:06:22.523799 | EPOCH: 15 | LOSS: 522.06669
2023-02-02 16:06:37.155365 | EPOCH: 16 | LOSS: 522.06669
2023-02-02 16:06:51.703793 | EPOCH: 17 | LOSS: 522.06669
2023-02-02 16:07:06.309624 | EPOCH: 18 | LOSS: 522.06669

Any suggestions on what might be going wrong here?

Your outputs contain negative values, so they don’t seem to represent probabilities.
Note that applying softmax to your output tensor will not allow the model to learn any negative values, so you should at least revisit this issue.
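
As a quick illustration with random logits: the softmax output is strictly positive and each row sums to 1, so negative target values can never be matched exactly.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 10)
probs = F.softmax(logits, dim=1)
print(probs.min())       # > 0, softmax never produces negative values
print(probs.sum(dim=1))  # each row sums to 1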

Generally, I would try to overfit a small dataset (e.g. just 10 samples) and make sure your model is able to do so.
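
To restrict the training to a handful of samples you could e.g. wrap your dataset in a torch.utils.data.Subset (just a sketch, reusing the dataset and BATCH_SIZE from the example below):

from torch.utils.data import DataLoader, Subset

small_dataset = Subset(dataset, list(range(10)))  # keep only the first 10 samples
small_loader = DataLoader(small_dataset, batch_size=BATCH_SIZE, shuffle=True)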
Using random input and target data with Adam seems to work fine:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

BATCH_SIZE = 8
LEARNING_RATE = 0.05
MOMENTUM = 0.00
EPOCHS = 50
INPUT_SIZE = 40
HIDDEN_SIZE = 90
OUTPUT_SIZE = 10
DROPOUT = 0.8
EVAL_STEP = 1


class BADataset(Dataset):
    def __init__(self):
        self.X = torch.randn(100, INPUT_SIZE)
        self.y = F.softmax(torch.randn(100, OUTPUT_SIZE), dim=1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        return self.X[index], self.y[index]


class Net(nn.Module):
    def __init__(self, inputsize, hiddensize, outputsize, dropout):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(inputsize, hiddensize)
        self.fc2 = nn.Linear(hiddensize, outputsize)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x

net = Net(INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, DROPOUT)
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)
dataset = BADataset()
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


for epoch in range(200):
    net.train()
    epoch_loss = list()
    cum_loss_train = list()
    running_loss = 0.0
    for idx, data in enumerate(train_loader):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss.append(loss.item())
    cum_loss_train.append(np.mean(epoch_loss))
    print(f'EPOCH: {epoch} | LOSS: {np.round(np.mean(epoch_loss), 5)}')
# ...
# EPOCH: 199 | LOSS: 6e-05

print(outputs)
# tensor([[0.0347, 0.3176, 0.0261, 0.2450, 0.0678, 0.0376, 0.0277, 0.1877, 0.0315,
#          0.0244],
#         [0.0276, 0.0988, 0.2840, 0.0491, 0.0398, 0.0820, 0.1149, 0.0317, 0.0231,
#          0.2488],
#         [0.1692, 0.2242, 0.0874, 0.0373, 0.0676, 0.0474, 0.0997, 0.1756, 0.0523,
#          0.0393],
#         [0.2727, 0.0369, 0.0347, 0.0780, 0.3089, 0.0728, 0.0360, 0.0683, 0.0725,
#          0.0192]], grad_fn=<SoftmaxBackward0>)

print(labels)
# tensor([[0.0343, 0.3356, 0.0185, 0.2383, 0.0662, 0.0299, 0.0299, 0.1859, 0.0352,
#          0.0264],
#         [0.0169, 0.1010, 0.2851, 0.0449, 0.0439, 0.0829, 0.1194, 0.0378, 0.0152,
#          0.2528],
#         [0.1673, 0.2302, 0.0882, 0.0353, 0.0684, 0.0469, 0.0998, 0.1708, 0.0554,
#          0.0379],
#         [0.2811, 0.0313, 0.0261, 0.0793, 0.3020, 0.0771, 0.0319, 0.0772, 0.0733,
#          0.0207]])

This works, although the probability output in combination with a softmax output layer is an uncommon setup.
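
If your targets really are (non-negative, renormalized) distributions, a more common setup would be to return the raw logits from the model and train with nn.KLDivLoss on the log-probabilities instead of applying MSELoss to softmax outputs. A rough sketch, assuming the softmax is removed from forward:

criterion = nn.KLDivLoss(reduction='batchmean')

logits = net(inputs)                                     # forward now returns raw logits
loss = criterion(F.log_softmax(logits, dim=1), labels)   # labels are probabilities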