Cannot overfit MLP on random batch sample

I’ve come across a strange situation that is really annoying me, and I still cannot figure out what is going on.

I’m trying to train an MLP and have it overfit a single batch of randomly generated data. Initially I thought this would be easy, but I’m still trying and can’t figure out what’s going wrong. I’ve tried a bunch of things: increasing model capacity, regularization, different learning rates, etc.

Yet I still cannot make the model overfit a random sample of labels.
At this point I’m all ears if anyone has better ideas as to what is going on.

MWE:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPCharPred(nn.Module):
    def __init__(self, num_tokens, num_outputs, embedding_dim=64, hidden_dim=128):
        super(MLPCharPred, self).__init__()
        self.embedding = nn.Embedding(num_tokens, embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 2*hidden_dim)
        self.fc3 = nn.Linear(2*hidden_dim, num_outputs)

    def forward(self, x, mask=False):
        x = x.transpose(0, 1)                  # [batch_size, seq_len] -> [seq_len, batch_size]
        embedded = self.embedding(x)           # [seq_len, batch_size, embedding_dim]
        x = F.relu(self.fc1(embedded))
        x = F.relu(self.fc2(x))
        output = self.fc3(x)                   # [seq_len, batch_size, num_outputs]

        if mask:
            # Causal masking: each position keeps the sum of its own and all earlier positions.
            mask = torch.tril(torch.ones(output.shape[0], output.shape[0])).to(x.device)  # [seq_len, seq_len]
            mask = mask.unsqueeze(0).unsqueeze(-1)                            # [1, seq_len, seq_len, 1]
            mask = mask.expand(output.shape[1], -1, -1, -1)                   # [batch_size, seq_len, seq_len, 1]
            output = output.transpose(0, 1).unsqueeze(2)                      # [batch_size, seq_len, 1, num_outputs]
            output = output.expand(-1, output.shape[1], output.shape[1], -1)  # [batch_size, seq_len, seq_len, num_outputs]
            output = output * mask
            output = output.sum(dim=2)                                        # [batch_size, seq_len, num_outputs]
            output = output.transpose(0, 1).contiguous()                      # [seq_len, batch_size, num_outputs]
        return torch.sigmoid(output)
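
# Quick shape check of the forward pass (illustrative sketch, not part of the original
# MWE; uses the example dimensions defined just below). With mask=True the causal
# summation keeps the same [seq_len, batch_size, num_outputs] shape:
_check_model = MLPCharPred(num_tokens=3, num_outputs=3)
_check_x = torch.randint(0, 3, size=(128, 100))       # [batch_size, seq_len]
print(_check_model(_check_x).shape)                   # torch.Size([100, 128, 3])
print(_check_model(_check_x, mask=True).shape)        # torch.Size([100, 128, 3])
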
# Example dimensions
num_tokens = 3  # Vocabulary size (number of distinct token ids)
d_model = 32  # Unused in this MWE
num_outputs = 3  # Number of classes or output dimension
seq_len = 100  # Maximum sequence length
batch_size = 128

model = MLPCharPred(num_tokens, num_outputs)


# Random data
x = torch.randint(0, 3, size=(batch_size, seq_len))  # Shape: [batch_size, seq_len]
y = torch.randint(0, num_outputs, size=(batch_size, seq_len))  # Shape: [batch_size, seq_len]

# Convert the target class indices to one-hot encoding
# targets_one_hot will have shape [batch_size, seq_len, num_classes]
targets_one_hot = torch.eye(num_outputs)[y]  # Using one-hot encoding
# assign more than one label per target
targets_one_hot[0, :20, 1] = 1.0
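# Aside (illustrative, not part of the original MWE): F.one_hot produces the same
# encoding as the torch.eye indexing trick above.
assert torch.equal(torch.eye(num_outputs)[y], F.one_hot(y, num_classes=num_outputs).float())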

opt = torch.optim.Adam(model.parameters(), lr=0.001)

model.train()

for epoch in range(1000):
    opt.zero_grad()  # reset gradients accumulated in the previous iteration
    output = model(x)  # Forward pass, shape [seq_len, batch_size, num_outputs]
    output = output.transpose(0, 1).contiguous()  # Change shape to [batch_size, seq_len, num_outputs]
    mask = (x != 0).float()  # ignore positions whose input token id is 0
    loss = nn.MSELoss(reduction="none")(output.view(-1, num_outputs), targets_one_hot.view(-1, num_outputs)).mean(1)
    loss = (mask.view(-1) * loss.view(-1)).sum() / mask.sum()
    print(f"Epoch: {epoch}\tLoss: {loss.item()}")
    loss.backward()
    opt.step()
The loss just stays stuck around 0.22:

Epoch: 0    Loss: 0.22xxxx
Epoch: 1    Loss: 0.22xxxx
...
Epoch: 999  Loss: 0.22xxxx

To check for overfitting, plot the loss curves: overfitting can be seen when the validation loss starts going up while the training loss keeps going down.
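
A minimal plotting sketch for that check (assuming train_losses and val_losses are per-epoch lists collected during training; both names are illustrative, since the MWE above has no validation split):

import matplotlib.pyplot as plt

plt.plot(train_losses, label="train")       # loss on the batch being overfitted
plt.plot(val_losses, label="validation")    # loss on held-out data
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.show()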