I’m trying to fit a GRU model on text data to predict one of 26 labels. The problem is that the model is not really learning: accuracy is around 4%, which is no better than random chance for 26 classes. Since I know that the problem is “learnable”, I suspect that there’s a bug in my code. I’d appreciate it if someone could take a look.
My data consists of (tokenized and word-encoded) 100K sentences per label (each sentence is mapped to one of 26 labels). My task is to predict the label of a new unseen sentence.
I tried several approaches, such as using a batch size > 1 together with padding, but the approach I’m sticking with for now is to join every 100 sentences into a single batch, so that my samples become larger, and to fit the model one batch at a time.
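For context, a padded-batch setup of the kind mentioned above could look roughly like the following. This is only a minimal sketch: the token ids, the tiny vocabulary and layer sizes, and the `pad_idx` name are assumptions for illustration, not the actual pipeline.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence

# Toy data (assumed): three encoded sentences of different lengths.
# Index 0 is reserved for padding in this sketch.
pad_idx = 0
sentences = [
    torch.tensor([5, 2, 9, 4]),
    torch.tensor([7, 3]),
    torch.tensor([8, 6, 1]),
]
lengths = torch.tensor([len(s) for s in sentences])

# Pad to a common length -> shape (batch, max_len).
padded = pad_sequence(sentences, batch_first=True, padding_value=pad_idx)

embedding = nn.Embedding(num_embeddings=10, embedding_dim=8, padding_idx=pad_idx)
gru = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

# Packing lets the GRU skip the padded positions, so the final hidden
# state of each sentence comes from its last real token.
packed = pack_padded_sequence(
    embedding(padded), lengths, batch_first=True, enforce_sorted=False
)
_, h_n = gru(packed)      # h_n: (num_layers, batch, hidden_size)
last_hidden = h_n[-1]     # (batch, hidden_size)
```

Without packing, `out[:, -1, :]` on a padded batch would read the output at a padding position for every sentence shorter than the longest one.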
class GRU(nn.Module):
    def __init__(self, input_size, num_classes, batch_size):
        super(GRU, self).__init__()
        self.hidden_state = None
        self._batch_first = True
        self.batch_size = batch_size
        self.hidden_size = 256
        self.num_layers = 1
        embedding_dim = 256
        self.embedding = nn.Embedding(input_size, embedding_dim=embedding_dim)
        nn.init.uniform_(self.embedding.weight, -1.0, 1.0)
        self.gru = nn.GRU(embedding_dim, self.hidden_size, self.num_layers, batch_first=self._batch_first)
        self.fc = nn.Linear(self.hidden_size, num_classes)

    def init_hidden(self):
        self.hidden_state = torch.randn(self.num_layers, self.batch_size, self.hidden_size).to(device)

    def forward(self, x):
        embeds = self.embedding(x)
        out, self.hidden_state = self.gru(embeds, self.hidden_state)
        out = out[:, -1, :]
        out = self.fc(out)
        return out

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
learning_rate = 0.001
optimizer = lambda mdl: torch.optim.Adam(mdl.parameters(), lr=learning_rate)

model = RNN(len(vocab), len(encoded_lbls), BATCH_SIZE).to(device)
# RNN(
#   (embedding): Embedding(19353, 256)
#   (rnn): GRU(256, 256, batch_first=True)
#   (fc): Linear(in_features=256, out_features=26, bias=True)
# )
I tried different learning rates, and different losses such as NLLLoss with a LogSoftmax, but that made no difference.
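That the loss swap made no difference is expected: `nn.CrossEntropyLoss` applied to raw logits computes the same value as `nn.NLLLoss` applied to `nn.LogSoftmax` outputs. A small check, where the random logits and the label values are made up for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 26)              # assumed: batch of 4, 26 classes
targets = torch.tensor([0, 5, 12, 25])   # assumed labels

# Cross-entropy on raw logits.
ce = nn.CrossEntropyLoss()(logits, targets)
# Negative log-likelihood on log-softmax outputs.
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)

same = torch.allclose(ce, nll)
print(ce.item(), nll.item(), same)
```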
Since I think that word n-grams are a good feature for this problem, I split each batch into word trigrams and feed them to the model one n-gram at a time, resetting the hidden state before every batch:
model.train(mode=True)
for epoch in range(epochs):
    for label, encoded_txt in train_loader:
        encoded_txt, label = encoded_txt.to(device), label.to(device)
        model.init_hidden()
        output, loss, _ = evaluate(model, optim, encoded_txt, label, train=True)
    # validation eval...
def evaluate(model, optim, txt, label, train=False):
    for ngram in txt.split(NGRAM_LEN):  # NGRAM_LEN = 3
        output = model(ngram)
    loss = criterion(output, label)
    if train:
        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        for p in model.parameters():
            p.data.add_(p.grad, alpha=-learning_rate)
        optim.step()
    accuracy = np.mean(np.array([item.item() for item in torch.argmax(output, dim=1)]) == label.cpu().numpy())
    return output, loss.item(), accuracy
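One thing worth checking in the trigram splitting is which dimension `Tensor.split` cuts along, since its `dim` argument defaults to 0. The sketch below assumes a `(batch, seq_len)` input with made-up token ids; whether that matches the real loader output is an assumption.

```python
import torch

NGRAM_LEN = 3
# Assumed shape: (batch, seq_len) = (2, 6); token ids are made up.
encoded_txt = torch.arange(12).reshape(2, 6)

# Default dim=0 chunks along the batch dimension.
chunks_dim0 = encoded_txt.split(NGRAM_LEN)
# dim=1 chunks along the sequence dimension, giving trigrams per sample.
chunks_dim1 = encoded_txt.split(NGRAM_LEN, dim=1)

shapes_dim0 = [tuple(c.shape) for c in chunks_dim0]
shapes_dim1 = [tuple(c.shape) for c in chunks_dim1]
print(shapes_dim0)
print(shapes_dim1)
```

Under this assumed shape, only the `dim=1` call yields chunks of three tokens per sample.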
This is what I’m getting after 10 epochs:
Epoch 0: Training Loss: 3.3762, Validation Loss: 3.4029, Validation Accuracy: 3.87%
Epoch 1: Training Loss: 3.3084, Validation Loss: 3.5362, Validation Accuracy: 3.89%
Epoch 2: Training Loss: 3.1202, Validation Loss: 3.8107, Validation Accuracy: 4.32%
Epoch 3: Training Loss: 2.9897, Validation Loss: 4.0599, Validation Accuracy: 4.57%
Epoch 4: Training Loss: 2.9118, Validation Loss: 4.3766, Validation Accuracy: 3.93%
Epoch 5: Training Loss: 2.9161, Validation Loss: 4.4962, Validation Accuracy: 4.23%
Epoch 6: Training Loss: 2.9117, Validation Loss: 4.7663, Validation Accuracy: 4.47%
Epoch 7: Training Loss: 2.9203, Validation Loss: 4.9078, Validation Accuracy: 4.55%
Epoch 8: Training Loss: 2.9253, Validation Loss: 5.1911, Validation Accuracy: 4.49%
Epoch 9: Training Loss: 2.9592, Validation Loss: 5.4946, Validation Accuracy: 4.23%
I’m hoping for at least 60% accuracy on the validation set, but as you can see it’s no better than random chance.
The training loss plateaus at around 2.9 after a few epochs, while the validation loss keeps increasing.
I can’t call this overfitting, since the training loss is still pretty high, so the model isn’t really learning.
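For reference, the random-chance numbers that the log can be compared against work out as follows. This sketch assumes uniform guessing over 26 balanced classes; the training data has 100K sentences per label, but the balance of the validation set is an assumption.

```python
import math

num_classes = 26
chance_accuracy = 1 / num_classes      # uniform guessing
chance_loss = math.log(num_classes)    # cross-entropy of a uniform prediction

print(f"chance accuracy: {chance_accuracy:.2%}")
print(f"chance loss:     {chance_loss:.4f}")
```

A cross-entropy loss near ln(26) together with accuracy near 1/26 is what a model that has learned nothing would produce.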
Any ideas would be appreciated!