RNN with truncated BPTT

I implemented a basic RNN from scratch on MNIST. The sequence X1, X2, …, X28 consists of the 28 rows of an MNIST digit; the features are the 28 pixels in each row.

It works fine. Now I would like to add truncated backpropagation through time (truncated BPTT) to it. I know detach() should be used somewhere, but I can't figure out where or how.

I am pretty sure it should go in the forward propagation function, so here is that function:

import torch
from torch import tanh, softmax, matmul as m
...
Wxh = torch.empty(i_r, h_r, device=device, requires_grad=True)
Whh = torch.empty(h_r, h_r, device=device, requires_grad=True)
Why = torch.empty(h_r, o_r, device=device, requires_grad=True)
Bh = torch.empty(h_r, device=device, requires_grad=True)
Bhy = torch.empty(o_r, device=device, requires_grad=True)
pars = [Wxh, Whh, Why, Bh, Bhy]
for par in pars:
    torch.nn.init.normal_(par, mean=0, std=std)
...
def forward_prop(X):
    X = X.permute(1, 0, 2)    # (batch, seq, feature) -> (seq, batch, feature)
    ss, bs, _ = X.shape
    H = torch.zeros(bs, h_r, device=device)   # initial hidden state
    for k in range(ss):
        # recurrence: new hidden state from previous state and current row
        H = tanh(m(H, Whh) + m(X[k], Wxh) + Bh)
    Z = m(H, Why) + Bhy       # output layer on the final hidden state
    Y = softmax(Z, dim=1)
    return Y
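Here is my current guess at where detach() would go: cut the graph inside the time loop every few steps, so gradients stop flowing further back than the last cut. This is only a sketch of what I imagine, with made-up names (`forward_prop_tbptt`, `trunc_len`) and small random weights for illustration; I am not sure it is correct:

```python
import torch
from torch import tanh, matmul as m

torch.manual_seed(0)
# Same shapes as in my code above: 28 features, hidden size h_r, 10 classes.
i_r, h_r, o_r, bs, ss = 28, 16, 10, 4, 28
Wxh = torch.randn(i_r, h_r, requires_grad=True)
Whh = torch.randn(h_r, h_r, requires_grad=True)
Why = torch.randn(h_r, o_r, requires_grad=True)
Bh = torch.zeros(h_r, requires_grad=True)
Bhy = torch.zeros(o_r, requires_grad=True)

def forward_prop_tbptt(X, trunc_len=7):
    X = X.permute(1, 0, 2)    # (batch, seq, feature) -> (seq, batch, feature)
    ss, bs, _ = X.shape
    H = torch.zeros(bs, h_r)
    for k in range(ss):
        if k > 0 and k % trunc_len == 0:
            # Detach the hidden state: backprop will not go through
            # time steps before this point.
            H = H.detach()
        H = tanh(m(H, Whh) + m(X[k], Wxh) + Bh)
    Z = m(H, Why) + Bhy
    return torch.softmax(Z, dim=1)

Y = forward_prop_tbptt(torch.randn(bs, ss, i_r))
```

Is detaching inside the loop like this the right idea, or does truncated BPTT have to be done in the training loop instead (e.g. by splitting the sequence into chunks and calling backward() on each)?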

How should I modify forward_prop() to add truncated BPTT?

Thanks in advance.