Backward pass is outrageously slow

Hello,
I’m trying to implement RNN-based text generation on sequences of different lengths, using padding and a masked cross-entropy loss. Here is a snippet of the critical code.
Each backward step of the loss takes about 30 s, while everything above it in the loop (forward pass, decoding and loss computation) takes less than 2 s.

Thanks in advance for any suggestions!

import csv
import numpy as np
import logging
import time
import string
from itertools import chain

import torch
import torch.nn as nn
import torch.optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from textloader import *
# from generate import *
logging.basicConfig(level=logging.INFO)


def maskedCrossEntropy(output, target, padcar):
    # output: (seq_len, batch, vocab) logits, target: (seq_len, batch) indices padded with padcar
    mask = target != padcar
    # cross_entropy expects the class dimension in position 1, hence the permute
    loss = F.cross_entropy(output.permute(0, 2, 1), target.long(), reduction="none") * mask
    return loss.sum() / mask.sum()

class RNN(nn.Module):
    def __init__(self, latent, dim, out):
        super().__init__()
        self.latent = latent
        self.dim = dim
        self.out = out
        self.lin_hs = nn.Linear(latent, latent)
        self.lin_ft = nn.Linear(dim, latent)
        self.lin_dec = nn.Linear(latent, out) 

    def decode(self, hs):
        # project hidden states to vocabulary-sized logits
        return self.lin_dec(hs)

    def forward(self, batch, hs):
        # batch: (seq_len, batch_size, dim) embedded inputs; returns the hidden state of every time step
        l = []
        for i in range(batch.shape[0]):
            hs = self.one_step(batch[i, :], hs)
            l.append(hs)
        return torch.stack(l)
    
    def one_step(self, batch, hs):
        # one recurrence step: h_t = tanh(W_h h_{t-1} + W_x x_t)
        return torch.tanh(self.lin_hs(hs) + self.lin_ft(batch))
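
# Shape contract of the RNN above, as a quick sanity check (sizes here are
# illustrative, not taken from the real data):
#   rnn = RNN(latent=50, dim=50, out=97)
#   seq = torch.randn(100, 8, 50)        # (seq_len, batch, embed_dim)
#   h0 = torch.zeros(8, 50)              # (batch, latent)
#   rnn(seq, h0)                         # -> (100, 8, 50): one hidden state per step
#   rnn.decode(rnn(seq, h0))             # -> (100, 8, 97): logits over the vocabulary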

speech = ""
with open('data/full_speech.txt') as f:
    while True:
        c = f.read(1)
        speech += c
        if not c:
            break

LR = 10e-3
SEQ_LEN = 100
PRED_LEN = 10
EMBED_DIM = 50
LATENT_DIM = 50
BATCH_SIZE = 500
EPOCH_RANGE = 5


embedding = nn.Embedding(len(id2lettre), EMBED_DIM)
speech_dataset = TextDataset(speech)
speech_dataloader = DataLoader(speech_dataset, BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=collate_fn)
# the RNN input dimension must match the embedding dimension
model = RNN(LATENT_DIM, EMBED_DIM, len(id2lettre))
optim = torch.optim.Adam(model.parameters(), lr=LR)
hs = torch.zeros(BATCH_SIZE, LATENT_DIM)

for epoch in range(EPOCH_RANGE):
    print(epoch)
    for x in speech_dataloader:
        optim.zero_grad()
        # x: (seq_len, batch) character indices; embed, run the RNN, decode every hidden state
        hst = model(embedding(x.long()), hs)
        hst = model.decode(hst)
        l = maskedCrossEntropy(hst, x, PAD_IX)
        l.backward()
        optim.step()
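
For reference, the 30 s vs 2 s figures come from timing iterations of the loop above, roughly along these lines (a sketch of the measurement, assuming time.perf_counter rather than the exact instrumentation used):

for x in speech_dataloader:
    optim.zero_grad()
    t0 = time.perf_counter()
    hst = model.decode(model(embedding(x.long()), hs))
    l = maskedCrossEntropy(hst, x, PAD_IX)
    t1 = time.perf_counter()   # forward pass + loss: well under 2 s
    l.backward()
    t2 = time.perf_counter()   # backward alone: around 30 s
    optim.step()
    print(f"forward+loss {t1 - t0:.1f}s, backward {t2 - t1:.1f}s")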