How to optimize a slow LSTM

I’m training a language model on a GPU in Colab, but it’s really slow. How can I optimize my code?

This is my code:

1. The training loop:
    for epoch in range(epochs):
        states = (torch.FloatTensor(2, bs, 200).uniform_(-0.1, 0.1).to(device),
                  torch.FloatTensor(2, bs, 200).uniform_(-0.1, 0.1).to(device))
        states = detach(states)
        for i, batch in enumerate(train_iter):
            model.train()
            inputs, labels = batch.text, batch.target
            inputs, labels = inputs.to(device), labels.to(device)
            outputs, states = model(inputs, states)
            loss = loss_fn(outputs, labels.view(-1))
            model.zero_grad()
            loss.backward(retain_graph=True)
            clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            if i % 100 == 0:
                print('i {}, Epoch [{}/{}], Loss: {:.4f}, Perplexity: {:5.2f}'
                      .format(i, epoch+1, epochs, loss.item(), np.exp(loss.item())))
2. The model:
    class RNNModel(nn.Module):
        def __init__(self,
                     rnn_type,
                     vocab_size=10001,
                     n_layers=2,
                     n_hidden=200,
                     emb_size=300,
                     p_dropout=0.,
                     batch_size=bs):
            super(RNNModel, self).__init__()
            self.embedding = nn.Embedding.from_pretrained(TEXT.vocab.vectors, freeze=True)  # (vocab_size, emb_size)
            # self.embedding.weight.data.copy_(TEXT.vocab.vectors)
            self.rnn = getattr(nn, rnn_type)(input_size=emb_size,
                                             num_layers=n_layers,
                                             hidden_size=n_hidden,
                                             dropout=p_dropout,
                                             batch_first=True)
            self.linear = nn.Linear(n_hidden, vocab_size)

        def forward(self, x, h):
            x = self.embedding(x)
            output, (h, c) = self.rnn(x, h)
            output = output.reshape(output.size(0) * output.size(1),
                                    output.size(2))
            output = self.linear(output)
            return output, (h, c)

Python-level loops are very slow; try to replace them with vectorized tensor operations where possible. You can also use TorchScript to optimize the model. Take a look here.
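
For example, here is a minimal sketch of compiling an LSTM language model with torch.jit.script. The class name ScriptableLSTM, the plain nn.Embedding in place of your pretrained vectors, and the hyperparameter values are illustrative assumptions, not your exact model; TorchScript also needs type annotations for non-Tensor arguments such as the hidden-state tuple:

    from typing import Tuple

    import torch
    import torch.nn as nn

    class ScriptableLSTM(nn.Module):
        # Illustrative stand-in for the RNNModel above, written so TorchScript can compile it.
        def __init__(self, vocab_size: int = 10001, emb_size: int = 300,
                     n_hidden: int = 200, n_layers: int = 2):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, emb_size)
            self.rnn = nn.LSTM(emb_size, n_hidden, n_layers, batch_first=True)
            self.linear = nn.Linear(n_hidden, vocab_size)

        def forward(self, x: torch.Tensor,
                    states: Tuple[torch.Tensor, torch.Tensor]
                    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
            x = self.embedding(x)
            output, (h, c) = self.rnn(x, states)
            # Flatten (batch, seq_len, n_hidden) to (batch * seq_len, n_hidden) before the projection.
            output = self.linear(output.reshape(-1, output.size(2)))
            return output, (h, c)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Compile once; the scripted module is then used like the eager model.
    model = torch.jit.script(ScriptableLSTM()).to(device)

The scripted module is a drop-in replacement for the eager model in your training loop (outputs, states = model(inputs, states)).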