Every single model output in a batch is the same; model does not converge or overfit even with 10 samples

While creating a model I came across the effect that the model always converges to a state where every sample in a batch produces the same output, independent of the actual label. With a batch size of 1 the model also does not converge/overfit.

Example of what I mean with batch_size = 4 after 783 epochs with only 10 samples (so the model should overfit, in my opinion):

epoch: 783 step: 5/5 loss = 0.701137363910675
Correct Labels:
tensor([[0., 1.],
        [1., 0.]])
Outputs of Model:
tensor([[-0.0208,  0.2322],
        [-0.0208,  0.2322]], grad_fn=<AddmmBackward0>)

What I have already tried:

  • Increased/decreased the number of layers
  • Increased/decreased the learning rate
  • Changed from CrossEntropyLoss to BCE (that's why the example below only has two different labels; my original model had more labels, which is why I used CrossEntropyLoss)
  • Verified that the dataset works correctly (in the minimal example below it is just example data, but it still reproduces the problem; the original data is also much more complex and diverse)
  • Reduced the feature size (in the minimal example below it is already reduced to a feature size of 10 and the issue persists)
  • Eliminated some other complexities that are not present in the minimalistic version below
  • Reduced the number of samples so the model should overfit (again, see the version below)

I'm pretty new to PyTorch, so it's very likely I made some simple mistake.
Below is the most minimalistic version of my problem that I could create. What am I missing?

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import math


class Dataset(Dataset):

    def __getitem__(self, index):
        if index == 0:
            feature_vector = [178.0500, 178.3200, 177.7100, 178.2300, 177.2700, 176.5800, 177.8900,
                              177.4400, 177.3900, 177.2600]
            label = [1., 0.]
        if index == 1:
            feature_vector = [178.3200, 177.7100, 178.2300, 177.2700, 176.5800, 177.8900, 177.4400,
                              177.3900, 177.2600, 178.4900]
            label = [0., 1.]
        if index == 2:
            feature_vector = [177.7100, 178.2300, 177.2700, 176.5800, 177.8900, 177.4400, 177.3900,
                              177.2600, 178.4900, 178.6800]
            label = [0., 1.]
        if index == 3:
            feature_vector = [178.2300, 177.2700, 176.5800, 177.8900, 177.4400, 177.3900, 177.2600,
                              178.4900, 178.6800, 178.4900]
            label = [0., 1.]
        if index == 4:
            feature_vector = [177.2700, 176.5800, 177.8900, 177.4400, 177.3900, 177.2600, 178.4900,
                              178.6800, 178.4900, 178.5100]
            label = [0., 1.]
        if index == 5:
            feature_vector = [176.5800, 177.8900, 177.4400, 177.3900, 177.2600, 178.4900, 178.6800,
                              178.4900, 178.5100, 175.9700]
            label = [0., 1.]
        if index == 6:
            feature_vector = [177.8900, 177.4400, 177.3900, 177.2600, 178.4900, 178.6800, 178.4900,
                              178.5100, 175.9700, 176.7300]
            label = [1., 0.]
        if index == 7:
            feature_vector = [177.4400, 177.3900, 177.2600, 178.4900, 178.6800, 178.4900, 178.5100,
                              175.9700, 176.7300, 176.8200]
            label = [0., 1.]
        if index == 8:
            feature_vector = [177.3900, 177.2600, 178.4900, 178.6800, 178.4900, 178.5100, 175.9700,
                              176.7300, 176.8200, 177.4800]
            label = [0., 1.]
        if index == 9:
            feature_vector = [177.2600, 178.4900, 178.6800, 178.4900, 178.5100, 175.9700, 176.7300,
                              176.8200, 177.4800, 175.4200]
            label = [1., 0.]
        feature_vector = torch.tensor(feature_vector, dtype=torch.float32, device="cpu")
        label = torch.tensor(label, dtype=torch.float32, device="cpu")
        return feature_vector, label

    def __len__(self):
        length = 10
        return length


class LinNet(nn.Module):
    def __init__(self):
        super(LinNet, self).__init__()
        self.lin1 = nn.Linear(10, 10)
        self.lin2 = nn.Linear(10, 10)
        self.lin3 = nn.Linear(10, 10)
        self.lin4 = nn.Linear(10, 10)
        self.lin5 = nn.Linear(10, 10)
        self.lin6 = nn.Linear(10, 10)
        self.lin7 = nn.Linear(10, 10)
        self.lin8 = nn.Linear(10, 10)
        self.lin9 = nn.Linear(10, 10)
        self.lin10 = nn.Linear(10, 2)

    def forward(self, x):
        x = torch.relu(self.lin1(x))
        x = torch.relu(self.lin2(x))
        x = torch.relu(self.lin3(x))
        x = torch.relu(self.lin4(x))
        x = torch.relu(self.lin5(x))
        x = torch.relu(self.lin6(x))
        x = torch.relu(self.lin7(x))
        x = torch.relu(self.lin8(x))
        x = torch.relu(self.lin9(x))
        x = self.lin10(x)
        return x


if __name__ == '__main__':
    batch_size = 4

    dataset = Dataset()

    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=1)

    model = LinNet()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

    num_epochs = 10000
    total_samples = len(dataset)
    total_iterations = math.ceil(total_samples / batch_size)

    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader):
            # get the inputs
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            # for name, param in model.named_parameters():
            #     print(name, param.grad.abs().sum())
            optimizer.step()

            # print statistics
            print("Labels:")
            print(labels)
            print("Outputs:")
            print(outputs)
            print(f"epoch: {epoch} step: {i + 1}/{total_iterations} loss = {loss.item()}")

I would try to remove the bias from your data by normalizing it: all of the features sit around 177–178, so the network mostly sees a large constant offset.
The model is also quite deep for this task (which makes convergence hard), so I would reduce the number of layers as well.
You could also change the learning rate and try out different optimizers such as Adam.
After a few changes (in particular using only 2 layers and normalizing the data) the model overfits the dataset:

class LinNet(nn.Module):
    def __init__(self):
        super(LinNet, self).__init__()
        self.lin9 = nn.Linear(10, 10)
        self.lin10 = nn.Linear(10, 2)

    def forward(self, x):
        x = torch.relu(self.lin9(x))
        x = self.lin10(x)
        return x


if __name__ == '__main__':
    batch_size = 4

    dataset = Dataset()

    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=1)

    model = LinNet()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    num_epochs = 10000
    total_samples = len(dataset)
    total_iterations = math.ceil(total_samples / batch_size)

    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader):
            # get the inputs
            inputs, labels = data
            # normalize the inputs with the dataset mean and standard deviation
            inputs = inputs - 177.6438
            inputs = inputs / 0.7894
            optimizer.zero_grad()
            outputs = model(inputs)
            # CrossEntropyLoss expects class indices, so convert the one-hot labels
            labels = labels.argmax(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            print("Labels:")
            print(labels)
            print("Outputs:")
            print(outputs)
            print(f"epoch: {epoch} step: {i + 1}/{total_iterations} loss = {loss.item()}")

Output:

...
epoch: 401 step: 1/3 loss = 0.007887040264904499
Labels:
tensor([1, 1, 0, 1])
Outputs:
tensor([[-3.4531,  2.4714],
        [-3.5186,  2.1545],
        [ 2.1726, -2.6982],
        [-3.0720,  2.0631]], grad_fn=<AddmmBackward0>)

Note that SGD would also work but is slower in this example.
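If you prefer not to hardcode the normalization constants, you can compute the mean and standard deviation from the dataset once before training; the 177.6438 and 0.7894 used above are just the statistics of this toy data. A minimal sketch, assuming the Dataset class and imports from the code above:

import torch
from torch.utils.data import DataLoader

dataset = Dataset()  # the toy Dataset class defined in the question

# Load every sample in a single batch to compute the statistics once.
stats_loader = DataLoader(dataset=dataset, batch_size=len(dataset), shuffle=False)
all_features, _ = next(iter(stats_loader))

mean = all_features.mean()
std = all_features.std()
print(mean, std)  # should be close to the hardcoded 177.6438 and 0.7894

# In the training loop, normalize each batch with these statistics:
# inputs = (inputs - mean) / std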