LSTM chatbot not learning

[Screenshot: training loss curve over epochs]
Output after 150 epochs:
Question: 'hello', Answer: 'i do '
Question: 'how was the traffic', Answer: 'a set for rules in the underlying entities in work'
Question: 'what is a bug', Answer: 'transparent system '

```
import torch
import torch.nn as nn

# Define LSTM model class
class NCGM1(nn.Module):
    def __init__(self, num_emb, num_layers=1, emb_size=128, hidden_size=128):
        super(NCGM1, self).__init__()
        
        # Embedding layer
        self.embedding = nn.Embedding(num_emb, emb_size)

        # MLP layer for embedding
        self.mlp_emb = nn.Sequential(
            nn.Linear(emb_size, emb_size),
            nn.LayerNorm(emb_size),
            nn.ELU(),
            nn.Linear(emb_size, emb_size)
        )
        
        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=emb_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.25  # note: nn.LSTM applies dropout only between stacked layers, so with num_layers=1 this has no effect
        )

        # MLP layer for output
        self.mlp_out = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.LayerNorm(hidden_size // 2),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_size // 2, num_emb)
        )
        
    def forward(self, input_seq, hidden_in, mem_in):
        # Embed input sequence (.long() casts the indices while keeping them on their current device)
        input_embs = self.embedding(input_seq.long())
        # Pass through MLP for embedding
        input_embs = self.mlp_emb(input_embs)
                
        # Pass through LSTM layer
        output, (hidden_out, mem_out) = self.lstm(input_embs, (hidden_in, mem_in))
                
        # Pass through MLP for output
        return self.mlp_out(output), hidden_out, mem_out
```
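
For reference, here is a minimal sketch of how this model can be invoked, assuming zero-initialized hidden and cell states and a placeholder vocabulary size of 5000 (the vocab size and batch shape below are illustrative assumptions, not values from the actual training setup):

```
import torch

# Sketch only: vocab size, batch shape, and zero-initialized states are placeholder assumptions
num_emb, emb_size, hidden_size, num_layers = 5000, 128, 128, 1
model = NCGM1(num_emb, num_layers=num_layers, emb_size=emb_size, hidden_size=hidden_size)

batch_size, seq_len = 4, 10
tokens = torch.randint(0, num_emb, (batch_size, seq_len))   # dummy token indices

# Hidden and cell states have shape (num_layers, batch_size, hidden_size)
h0 = torch.zeros(num_layers, batch_size, hidden_size)
c0 = torch.zeros(num_layers, batch_size, hidden_size)

logits, h_out, c_out = model(tokens, h0, c0)
print(logits.shape)   # torch.Size([4, 10, 5000]) -> per-token logits over the vocabulary
```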

Train accuracy, validation accuracy, and validation loss might help to diagnose this.


Validation loss would help, as Codenamics says. Also, if your validation loss is still steadily decreasing, your model can likely still learn more; your training graph seems to indicate the same, since the loss curve hasn't flattened yet.

Yeah, but it could still be overfitting. We need the validation loss.
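
One way to check this is to hold out a validation split and compute the loss on it after every epoch, alongside the training loss. Below is a minimal sketch assuming a `val_loader` that yields (input, target) token batches and cross-entropy as the loss; the loader, device handling, and shapes are assumptions, since none of that is shown in the original post:

```
import torch
import torch.nn as nn

# Hypothetical evaluation helper: val_loader, device, and batch shapes are assumptions
@torch.no_grad()
def evaluate(model, val_loader, device, num_layers=1, hidden_size=128):
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    total_loss, total_tokens, correct = 0.0, 0, 0

    for inputs, targets in val_loader:            # each of shape (batch, seq_len), long token ids
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.size(0)

        # Fresh zero-initialized states for each batch
        h = torch.zeros(num_layers, batch_size, hidden_size, device=device)
        c = torch.zeros(num_layers, batch_size, hidden_size, device=device)

        logits, _, _ = model(inputs, h, c)        # (batch, seq_len, vocab)
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        n = targets.numel()
        total_loss += loss.item() * n
        total_tokens += n
        correct += (logits.argmax(dim=-1) == targets).sum().item()

    model.train()
    return total_loss / total_tokens, correct / total_tokens   # val loss, val token accuracy
```

If the validation loss keeps tracking the training loss downward, the model is still learning; if it plateaus or rises while the training loss keeps falling, that is the overfitting being pointed out above.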