Training taking too long - basic LSTM POS Tagger

I am trying to train an LSTM POS tagger following https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html. When I wrote it yesterday, each epoch took about 10s. Since then I have edited the code in a few places (I was running on Colab and forgot to save a copy). It now reaches the same accuracy, but the first epoch takes around 60s and every later epoch about 45s. The train() call by itself takes about 40s.

I tried adding torch.no_grad() (in evaluate()), but it does not seem to make any difference.

def train(epoch, model, loss_function, optimizer):
    train_loss = 0
    train_examples = 0
    # model = BasicPOSTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_idx), len(tag_to_idx))

    for sentence, tags in training_data:
        model.zero_grad()

        # get inputs ready
        sent_in = prepare_sequence(sentence, word_to_idx)
        targets = prepare_sequence(tags, tag_to_idx)

        # run forward pass
        tag_scores = model(sent_in)
        # compute loss, gradients, update params
        loss = loss_function(tag_scores, targets)

        loss.backward() 

        optimizer.step()     # tried with and without, model.zero_grad()   
        train_loss += loss.item()
        loss.detach_() #tried with and without this line
        train_examples += 1

    avg_train_loss = train_loss / train_examples
    # print(avg_train_loss)
    avg_val_loss, val_accuracy = evaluate(model, loss_function, optimizer)
        
    print("Epoch: {}/{}\tAvg Train Loss: {:.4f}\tAvg Val Loss: {:.4f}\t Val Accuracy: {:.0f}".format(epoch, 
                                                                      EPOCHS, 
                                                                      avg_train_loss, 
                                                                      avg_val_loss,
                                                                      val_accuracy))

def evaluate(model, loss_function, optimizer):
    val_loss = 0
    correct = 0.
    val_examples = 0
    with torch.no_grad():
        for sentence, tags in val_data:
            # get inputs ready
            sent_in = prepare_sequence(sentence, word_to_idx)
            targets = prepare_sequence(tags, tag_to_idx)

            # run forward pass
            tag_scores = model(sent_in)
 
            # compute loss only; no gradients are tracked inside torch.no_grad()
            loss = loss_function(tag_scores, targets)
            val_loss += loss.item()  # accumulate a Python float, as in train()
            indices = torch.argmax(tag_scores,dim=1)
            pred = [tag for a in indices for (tag, index) in tag_to_idx.items() if index == a.item()]
            correct += sum([1 for i,j in zip(pred,tags) if i==j])/len(pred)
            val_examples += 1            
    val_accuracy = 100. * correct / val_examples
    avg_val_loss = val_loss / val_examples
    return avg_val_loss, val_accuracy

Calling it:

for epoch in range(1, 30 + 1): 
    train(epoch, model, loss_function, optimizer)

I’m using NLLLoss() and the SGD optimizer.
training_data is a list of lists where [i][0] is the tokenized sentence and [i][1] is its list of corresponding POS tags. I’m using the same parameters as before (embedding dimensions, learning rate, etc.), so I believe the issue is within the train() function, but I cannot find where. Is my training wrong somewhere? Any help is very much appreciated, thank you in advance.
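For readers unfamiliar with the tutorial, here is a minimal sketch of the setup described above. It assumes the model mirrors the tutorial's LSTMTagger; the BasicPOSTagger body, the toy sentences and tags, the dimensions, and the learning rate are illustrative assumptions, not the poster's actual code. prepare_sequence follows the tutorial helper of the same name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy data in the layout described above: [i][0] = tokens, [i][1] = tags.
# These sentences and tags are made up for illustration.
training_data = [
    (["the", "dog", "ate", "the", "apple"], ["DET", "NN", "V", "DET", "NN"]),
    (["everybody", "read", "that", "book"], ["NN", "V", "DET", "NN"]),
]

# Assign each new word/tag the next free index.
word_to_idx, tag_to_idx = {}, {}
for sentence, tags in training_data:
    for w in sentence:
        word_to_idx.setdefault(w, len(word_to_idx))
    for t in tags:
        tag_to_idx.setdefault(t, len(tag_to_idx))


def prepare_sequence(seq, to_idx):
    # Same idea as the tutorial helper: tokens -> 1-D LongTensor of indices.
    return torch.tensor([to_idx[w] for w in seq], dtype=torch.long)


class BasicPOSTagger(nn.Module):
    # Hypothetical reconstruction modeled on the tutorial's LSTMTagger.
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)  # expects (seq, batch, feat)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)                        # (seq_len, emb)
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1))     # batch of one
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))  # (seq_len, tags)
        return F.log_softmax(tag_space, dim=1)  # log-probs, as NLLLoss expects


torch.manual_seed(1)
model = BasicPOSTagger(6, 6, len(word_to_idx), len(tag_to_idx))
loss_function = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

sent_in = prepare_sequence(training_data[0][0], word_to_idx)
targets = prepare_sequence(training_data[0][1], tag_to_idx)
tag_scores = model(sent_in)
loss = loss_function(tag_scores, targets)
print(tag_scores.shape)  # torch.Size([5, 3])
```

Each call processes a single sentence, which is the "batch of one" that the reply below refers to.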

I’m not quite sure what you’re asking. I take it the training itself works, in the sense that the loss goes down and the accuracy goes up, right?

So your question is about speed. On its own, 60s/epoch doesn’t say much without knowing the dataset size and the hardware (CPU or GPU), and knowing it was 10s/epoch before “some edits” doesn’t narrow it down either. At a quick glance your code looks fine: it follows the tutorial, and the training seems to work. I’m not quite sure what you’re looking for now.
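One spot in evaluate() that may be worth checking: the list comprehension that maps predicted indices back to tags scans the whole tag_to_idx dictionary for every token and calls .item() on each comparison. A sketch of two cheaper alternatives, assuming tag_to_idx maps each tag to a unique index (the toy tag set and scores below are made up): build the reverse dictionary once, or skip strings entirely and compare index tensors.

```python
import torch

tag_to_idx = {"DET": 0, "NN": 1, "V": 2}            # toy tag set for illustration
idx_to_tag = {i: t for t, i in tag_to_idx.items()}  # build once, outside the loop

tags = ["DET", "NN", "V", "DET", "NN"]  # gold tags for one sentence
tag_scores = torch.tensor([             # pretend model output: (seq_len, num_tags)
    [0.9, 0.05, 0.05],
    [0.1, 0.8, 0.1],
    [0.2, 0.2, 0.6],
    [0.1, 0.7, 0.2],
    [0.1, 0.8, 0.1],
]).log()

indices = torch.argmax(tag_scores, dim=1)

# Original approach: checks every (tag, index) pair per token, calling .item() each time.
pred_slow = [tag for a in indices for (tag, index) in tag_to_idx.items() if index == a.item()]

# Reverse lookup: one dictionary access per token.
pred_fast = [idx_to_tag[i] for i in indices.tolist()]

# Or compare index tensors directly; this is the per-sentence accuracy.
targets = torch.tensor([tag_to_idx[t] for t in tags])
acc = (indices == targets).float().mean().item()
print(pred_fast, acc)
```

With a small tag set the difference per sentence is modest, but it adds up over a large validation set that is evaluated every epoch.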

That said, one main bottleneck is that you feed the network only one sentence at a time (please correct me if I’m wrong). Using batches would likely speed up training significantly.
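A minimal sketch of what batching could look like. It assumes the sentences and tags are already index tensors (e.g. from prepare_sequence), that word index 0 is free to reserve for padding, and that the model is changed to take (batch, seq_len) input via batch_first=True. BatchedTagger, collate, PAD_IDX, TAG_PAD and the toy data are hypothetical names for illustration. Sentences are padded with pad_sequence, and padded tag positions use NLLLoss's ignore_index so they add nothing to the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

PAD_IDX = 0     # assumption: word index 0 is reserved for padding
TAG_PAD = -100  # NLLLoss's default ignore_index, so padded positions are skipped


class BatchedTagger(nn.Module):
    # Hypothetical batched variant of the tutorial's LSTMTagger.
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim, padding_idx=PAD_IDX)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, tagset_size)

    def forward(self, batch):                        # batch: (B, T) word indices
        lstm_out, _ = self.lstm(self.embed(batch))   # (B, T, hidden)
        return F.log_softmax(self.out(lstm_out), dim=-1)  # (B, T, tags)


def collate(pairs):
    # pairs: list of (word_idx_tensor, tag_idx_tensor); pad to the longest in the batch
    words = pad_sequence([w for w, _ in pairs], batch_first=True, padding_value=PAD_IDX)
    tags = pad_sequence([t for _, t in pairs], batch_first=True, padding_value=TAG_PAD)
    return words, tags


torch.manual_seed(0)
# Toy pre-indexed data; word index 0 never appears as a real token.
data = [
    (torch.tensor([1, 2, 3, 1, 4]), torch.tensor([0, 1, 2, 0, 1])),
    (torch.tensor([5, 6, 7, 8]), torch.tensor([1, 2, 0, 1])),
    (torch.tensor([2, 3]), torch.tensor([1, 2])),
]

model = BatchedTagger(8, 16, vocab_size=9, tagset_size=3)
loss_function = nn.NLLLoss(ignore_index=TAG_PAD)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = DataLoader(data, batch_size=2, shuffle=True, collate_fn=collate)

for epoch in range(5):
    for words, tags in loader:
        optimizer.zero_grad()
        scores = model(words)  # (B, T, tags)
        # NLLLoss expects (N, C) scores and (N,) targets, so flatten batch and time
        loss = loss_function(scores.reshape(-1, scores.size(-1)), tags.reshape(-1))
        loss.backward()
        optimizer.step()
```

Because the LSTM here is unidirectional and the padding sits at the end of each sentence, the outputs at real token positions are not affected by the padding; a bidirectional LSTM would need pack_padded_sequence instead. Validation accuracy would also need the same mask (tags != TAG_PAD) when counting correct predictions.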