I implemented a basic RNN from scratch on MNIST. The sequence X1, X2, …, X28 consists of the 28 rows of an MNIST digit; the features are the 28 pixels in each row.
It works fine. Now I would like to add truncated backpropagation through time (TBPTT). I know detach() should be used somewhere, but I can't figure out where and how.
I am pretty sure it belongs in the forward propagation function, so please find that function below:
import torch
from torch import tanh, softmax, matmul as m

...

Wxh = torch.empty(i_r, h_r, device=device, requires_grad=True)
Whh = torch.empty(h_r, h_r, device=device, requires_grad=True)
Why = torch.empty(h_r, o_r, device=device, requires_grad=True)
Bh  = torch.empty(h_r, device=device, requires_grad=True)
Bhy = torch.empty(o_r, device=device, requires_grad=True)

pars = [Wxh, Whh, Why, Bh, Bhy]
for par in pars:
    torch.nn.init.normal_(par, mean=0, std=std)

...

def forward_prop(X):
    X = X.permute(1, 0, 2)  # (seq, batch, feature)
    ss, bs, _ = X.shape
    H = torch.zeros(bs, h_r, device=device)
    for k in range(ss):
        H = tanh(m(H, Whh) + m(X[k], Wxh) + Bh)
    Z = m(H, Why) + Bhy
    Y = softmax(Z, dim=1)
    return Y
How should I modify forward_prop() to add truncated BPTT?
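Here is my current guess, in case it helps to have something concrete to correct: detach the hidden state every few steps so the graph only reaches back a fixed window. The function name forward_prop_tbptt and the trunc parameter are my own placeholders (not from my working code), and I'm not sure this is the right placement of detach():

```python
import torch
from torch import tanh

# My guess (untested assumption): detach H every `trunc` steps so gradients
# only flow through at most the last `trunc` time steps. Whh, Wxh, Bh and
# trunc are placeholders standing in for my existing parameters.
def forward_prop_tbptt(X, Whh, Wxh, Bh, trunc=7):
    X = X.permute(1, 0, 2)              # (seq, batch, feature)
    ss, bs, _ = X.shape
    H = torch.zeros(bs, Whh.shape[0], device=X.device)
    for k in range(ss):
        if k % trunc == 0:
            H = H.detach()              # cut the graph: gradient stops here
        H = tanh(H @ Whh + X[k] @ Wxh + Bh)
    return H
```

Is the idea simply to detach H at a fixed interval like this, or should the detaching happen between batches instead of inside the loop?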
Thanks in advance