I implemented a basic RNN from scratch on MNIST. The sequence X1, X2, …, X28 consists of the 28 rows of an MNIST digit, and each feature vector is the 28 pixels in a row.
It works fine. Now I would like to add truncated backpropagation through time (TBPTT). I know detach() should be used somewhere, but I can't figure out where or how.
I am pretty sure it belongs in the forward propagation function, so here it is:
import torch
from torch import tanh, softmax, matmul as m
...
Wxh = torch.empty(i_r, h_r, device=device, requires_grad=True)
Whh = torch.empty(h_r, h_r, device=device, requires_grad=True)
Why = torch.empty(h_r, o_r, device=device, requires_grad=True)
Bh = torch.empty(h_r, device=device, requires_grad=True)
Bhy = torch.empty(o_r, device=device, requires_grad=True)
pars = [Wxh, Whh, Why, Bh, Bhy]
for par in pars:
    torch.nn.init.normal_(par, mean=0, std=std)
...
def forward_prop(X):
    X = X.permute(1, 0, 2)  # (seq, batch, feature)
    ss, bs, _ = X.shape
    H = torch.zeros(bs, h_r, device=device)
    for k in range(ss):
        H = tanh(m(H, Whh) + m(X[k], Wxh) + Bh)
    Z = m(H, Why) + Bhy
    Y = softmax(Z, dim=1)
    return Y
How should I modify forward_prop() to add truncated BPTT?
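My current guess, which may well be wrong, is to detach the hidden state every few steps inside the loop. Here is a minimal sketch of that idea, with the weights passed in explicitly and a hypothetical truncation length `trunc` that is not part of my real code:

```python
import torch
from torch import tanh, matmul as m

def forward_prop_tbptt(X, Wxh, Whh, Why, Bh, Bhy, trunc=7):
    """Forward pass that cuts the autograd graph every `trunc` steps."""
    X = X.permute(1, 0, 2)  # (seq, batch, feature)
    ss, bs, _ = X.shape
    H = torch.zeros(bs, Whh.shape[0], device=X.device)
    for k in range(ss):
        if k > 0 and k % trunc == 0:
            # detach() returns a tensor cut off from the graph, so
            # gradients stop flowing further back in time than `trunc` steps
            H = H.detach()
        H = tanh(m(H, Whh) + m(X[k], Wxh) + Bh)
    Z = m(H, Why) + Bhy
    return torch.softmax(Z, dim=1)
```

Is this the right placement, or should the detach happen between minibatches instead of inside one forward pass?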
Thanks in advance