GRU is not learning

Hi,

I’m trying to fit a GRU model on text data to predict one of 26 labels. The problem is that the model is not really learning: accuracy hovers around 4%, which is essentially random chance (1/26 ≈ 3.85%). Since I know the problem is “learnable”, I suspect there’s a bug in my code. I’d appreciate it if someone could take a look.

My data consists of 100K (tokenized and word-encoded) sentences per label, where each sentence is mapped to one of 26 labels. The task is to predict the label of a new, unseen sentence.
I tried several approaches, such as using a batch size > 1 together with padding, but the approach I’m sticking with right now is joining every 100 sentences into a single batch, so the samples become larger, and fitting the model one batch at a time.
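
Roughly, the batching looks like this. This is a simplified sketch; the helper name make_joined_batches and the grouping details are illustrative, not my exact code:

import torch

def make_joined_batches(encoded_sents, labels, group_size=100):
    # Concatenate every `group_size` consecutive same-label sentences into
    # one long sample, so each batch is a single long token sequence.
    batches = []
    for i in range(0, len(encoded_sents), group_size):
        joined = torch.cat(encoded_sents[i:i + group_size])  # 1-D tensor of token ids
        batches.append((labels[i], joined.unsqueeze(0)))     # add a batch dimension
    return batches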

Model:

class GRU(nn.Module):
    def __init__(self, input_size, num_classes, batch_size):
        super().__init__()
        self.hidden_state = None
        self._batch_first = True
        self.batch_size = batch_size
        self.hidden_size = 256
        self.num_layers = 1
        embedding_dim = 256
        self.embedding = nn.Embedding(input_size, embedding_dim=embedding_dim)
        nn.init.uniform_(self.embedding.weight, -1.0, 1.0)
        self.gru = nn.GRU(embedding_dim, self.hidden_size, self.num_layers, batch_first=self._batch_first)
        self.fc = nn.Linear(self.hidden_size, num_classes)

    def init_hidden(self):
        # re-initialize the hidden state (randomly) before each batch
        self.hidden_state = torch.randn(self.num_layers, self.batch_size, self.hidden_size).to(device)

    def forward(self, x):
        embeds = self.embedding(x)                            # (batch, seq_len, embedding_dim)
        out, self.hidden_state = self.gru(embeds, self.hidden_state)
        out = out[:, -1, :]                                   # keep only the last time step
        out = self.fc(out)                                    # (batch, num_classes)
        return out

   
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
learning_rate = 0.001
model = GRU(len(vocab), len(encoded_lbls), BATCH_SIZE).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# GRU(
#   (embedding): Embedding(19353, 256)
#   (gru): GRU(256, 256, batch_first=True)
#   (fc): Linear(in_features=256, out_features=26, bias=True)
# )

I tried different learning rates, and a different loss formulation (NLLLoss on LogSoftmax outputs instead of CrossEntropyLoss), but that made no difference.
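
For reference, this is the kind of swap I mean. As far as I understand, CrossEntropyLoss is equivalent to LogSoftmax followed by NLLLoss, so it makes sense that this changed nothing (toy tensors, just for illustration):

import torch
import torch.nn as nn

logits = torch.randn(4, 26)           # (batch, num_classes) raw model outputs
labels = torch.randint(0, 26, (4,))   # ground-truth class indices

ce = nn.CrossEntropyLoss()(logits, labels)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)
assert torch.allclose(ce, nll)        # the two formulations agree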

Since I think word n-grams are a good feature for this problem, I split each batch into word trigrams and feed them to the model n-gram by n-gram, resetting the hidden state before every batch:

model.train()
for epoch in range(epochs):
    for label, encoded_txt in train_loader:
        encoded_txt, label = encoded_txt.to(device), label.to(device)
        model.init_hidden()  # reset the hidden state before every batch
        output, loss, _ = evaluate(model, optimizer, encoded_txt, label, train=True)

    # validation eval...

Here’s the evaluate() function:

def evaluate(model, optim, txt, label, train=False):
    # feed the batch trigram by trigram; the hidden state carries over between them
    for ngram in txt.split(NGRAM_LEN, dim=1):  # NGRAM_LEN = 3; split along the sequence dim
        output = model(ngram)
    loss = criterion(output, label)  # loss on the output of the last trigram only

    if train:
        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        for p in model.parameters():
            p.data.add_(p.grad, alpha=-learning_rate)  # manual SGD-style update...
        optim.step()                                   # ...followed by the Adam step

    accuracy = (torch.argmax(output, dim=1) == label).float().mean().item()

    return output, loss.item(), accuracy
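
To be explicit about what the splitting does, here is a toy example of Tensor.split along the sequence dimension (the shapes are just illustrative):

import torch

batch = torch.arange(12).reshape(1, 12)  # one sample of 12 token ids, batch_first layout
trigrams = batch.split(3, dim=1)         # chunk the sequence into trigrams
# -> 4 tensors of shape (1, 3): [[0,1,2]], [[3,4,5]], [[6,7,8]], [[9,10,11]]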

This is what I’m getting after 10 epochs:

Epoch 0: Training Loss: 3.3762, Validation Loss: 3.4029, Validation Accuracy: 3.87%
Epoch 1: Training Loss: 3.3084, Validation Loss: 3.5362, Validation Accuracy: 3.89%
Epoch 2: Training Loss: 3.1202, Validation Loss: 3.8107, Validation Accuracy: 4.32%
Epoch 3: Training Loss: 2.9897, Validation Loss: 4.0599, Validation Accuracy: 4.57%
Epoch 4: Training Loss: 2.9118, Validation Loss: 4.3766, Validation Accuracy: 3.93%
Epoch 5: Training Loss: 2.9161, Validation Loss: 4.4962, Validation Accuracy: 4.23%
Epoch 6: Training Loss: 2.9117, Validation Loss: 4.7663, Validation Accuracy: 4.47%
Epoch 7: Training Loss: 2.9203, Validation Loss: 4.9078, Validation Accuracy: 4.55%
Epoch 8: Training Loss: 2.9253, Validation Loss: 5.1911, Validation Accuracy: 4.49%
Epoch 9: Training Loss: 2.9592, Validation Loss: 5.4946, Validation Accuracy: 4.23%

I’m hoping for at least 60% accuracy on the validation set, but as you can see it’s no better than random chance.
The training loss barely decreases, while the validation loss keeps increasing.
I wouldn’t call it overfitting either, since the training loss stays high, so the model isn’t really learning in the first place.

Any ideas would be appreciated!