I’ve come across a strange situation that is really annoying me, and I still cannot figure out what is going on.
I’m trying to train an MLP and have it overfit a batch of randomly generated data. Initially I thought this would be easy, but I’m still trying and can’t figure out what’s going on. I’ve tried a bunch of things: increasing model complexity, regularization, different learning rates, etc.
Yet I still cannot make the model overfit a random sample of labels.
At this point I’m all ears if anyone has any better ideas as to what the heck is going on.
MWE:
class MLPCharPred(nn.Module):
    """Per-token MLP that maps each input token to ``num_outputs`` sigmoid scores.

    Every position is processed independently: the token id is embedded and
    passed through three Linear layers with ReLU between them. Nothing mixes
    information across positions except the optional causal prefix sum
    enabled with ``mask=True``. Consequently, with ``mask=False`` the
    prediction at a position depends only on that position's token id. With
    ``num_tokens`` distinct inputs, the model can emit at most ``num_tokens``
    distinct outputs. It therefore cannot fit labels that vary independently
    of the token, such as randomly drawn labels.

    Args:
        num_tokens: Vocabulary size (number of distinct token ids).
        num_outputs: Number of output scores per position.
        embedding_dim: Size of the token embedding.
        hidden_dim: Width of the first hidden layer; the second is twice this.
    """

    def __init__(self, num_tokens, num_outputs, embedding_dim=64, hidden_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(num_tokens, embedding_dim)
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 2 * hidden_dim)
        self.fc3 = nn.Linear(2 * hidden_dim, num_outputs)

    def forward(self, x, mask=False):
        """Score every position of ``x``.

        Args:
            x: Integer token ids. They are transposed before embedding, so
                ``x`` is expected as ``[batch, seq]`` (as produced by
                ``torch.randint(..., size=(batch_size, seq_len))``).
            mask: If True, replace each position's logits with the sum of the
                logits at that position and all earlier positions (a causal
                prefix sum) before applying the sigmoid.

        Returns:
            Tensor of shape ``[seq, batch, num_outputs]`` with values in
            (0, 1). Note that the sequence dimension comes first.
        """
        x = x.transpose(0, 1)  # [seq, batch]
        embedded = self.embedding(x)  # [seq, batch, embedding_dim]
        hidden = F.relu(self.fc1(embedded))
        hidden = F.relu(self.fc2(hidden))
        output = self.fc3(hidden)  # [seq, batch, num_outputs]
        if mask:
            # Causal aggregation: position i receives sum_{j <= i} of the logits.
            # This is what a lower-triangular [seq, seq] mask applied over the
            # sequence axis computes, but cumsum needs no O(seq^2) intermediate
            # and cannot broadcast over the wrong axis.
            output = torch.cumsum(output, dim=0)
        return torch.sigmoid(output)
# Example dimensions
num_tokens = 3  # vocabulary size: token ids are drawn from {0, ..., num_tokens - 1}
d_model = 32  # not used in this example
num_outputs = 3  # number of classes (output scores) per position
seq_len = 100  # sequence length
batch_size = 128

# MLPCharPred is the model class defined in this example.
model = MLPCharPred(num_tokens, num_outputs)

# Random data
x = torch.randint(0, num_tokens, size=(batch_size, seq_len))  # [batch_size, seq_len]
y = torch.randint(0, num_outputs, size=(batch_size, seq_len))  # [batch_size, seq_len]

# One-hot targets, shape [batch_size, seq_len, num_outputs].
targets_one_hot = torch.eye(num_outputs)[y]
# Assign more than one label per target: the first 20 positions of sample 0
# additionally get class 1.
targets_one_hot[0, :20, 1] = 1.0

# NOTE: MLPCharPred scores each position from its token id alone (``mask`` is
# left False here), so all positions holding the same token receive the same
# prediction. Because ``y`` is drawn independently of ``x``, the best such a
# model can do is predict the class frequencies (about 1/3 each). With MSE on
# one-hot targets, that gives a loss of about
# 1/3 * (2/3)**2 + 2/3 * (1/3)**2 = 2/9 ~= 0.222, which is the observed plateau.
# To memorize random labels, the model needs inputs that distinguish positions
# and samples (e.g. positional information or context across the sequence).

criterion = nn.MSELoss(reduction="none")
opt = torch.optim.Adam(model.parameters(), lr=0.001)
model.train()
# Positions whose input token is 0 are excluded from the loss.
mask = (x != 0).float()
for epoch in range(1000):
    # Without this, gradients from every previous epoch keep accumulating.
    opt.zero_grad()
    output = model(x)  # [seq_len, batch_size, num_outputs]
    output = output.transpose(0, 1).contiguous()  # [batch_size, seq_len, num_outputs]
    per_position = criterion(
        output.view(-1, num_outputs), targets_one_hot.view(-1, num_outputs)
    ).mean(1)
    loss = (mask.view(-1) * per_position).sum() / mask.sum()
    print(f"Epoch: {epoch}\tLoss: {loss.item()}")
    loss.backward()
    opt.step()
Output (the loss plateaus at about 0.22 and never goes lower):
Loss: 0.22xxxx
Loss: 0.22xxxx
Loss: 0.22xxxx
Epoch: 1000, Loss: 0.22xxxx