Hey,
I want to implement a modification to a vanilla RNN.
The standard recurrence is
h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
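For reference, here is a minimal sketch of that recurrence as an explicit loop over time steps. It assumes NumPy, a single unbatched sequence of shape (T, input_size), and column-vector conventions matching the formula. The function name `rnn_forward` is hypothetical, not from any library.

```python
import numpy as np


def rnn_forward(x, h0, W_ih, b_ih, W_hh, b_hh):
    """Vanilla RNN: h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh).

    x:  (T, input_size) input sequence
    h0: (hidden_size,) initial hidden state
    Returns all hidden states stacked, shape (T, hidden_size).
    """
    h = h0
    hs = []
    for x_t in x:  # iterate over time steps
        h = np.tanh(W_ih @ x_t + b_ih + W_hh @ h + b_hh)
        hs.append(h)
    return np.stack(hs)


# Small usage example with random weights
rng = np.random.default_rng(0)
T, input_size, hidden_size = 5, 3, 4
x = rng.standard_normal((T, input_size))
h0 = np.zeros(hidden_size)
W_ih = rng.standard_normal((hidden_size, input_size))
W_hh = rng.standard_normal((hidden_size, hidden_size))
b_ih = np.zeros(hidden_size)
b_hh = np.zeros(hidden_size)
hs = rnn_forward(x, h0, W_ih, b_ih, W_hh, b_hh)
print(hs.shape)  # (5, 4)
```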
I want to change it to
h_t = tanh(W_ih x_t + b_ih + C(W_hh) h_{t-1} + b_hh)
where C is a linear function, so C(W_hh) is a linear transformation of the recurrent weight matrix W_hh.
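One way to sketch this, assuming NumPy and a hypothetical choice of C: since the post doesn't pin down C, the example uses C(W) = A W B with extra square matrices A and B, which is linear in W. (A fully general linear map on W_hh would be vec(C(W_hh)) = M vec(W_hh) with M of shape (n², n²), i.e. n⁴ parameters for hidden size n.) The main point is that C(W_hh) does not depend on t, so it can be computed once per sequence and then used like an ordinary recurrent matrix. If you're using a framework with autograd such as PyTorch (the W_ih/b_ih/W_hh/b_hh naming matches its `nn.RNN` docs), one straightforward option is to write this loop in a custom module with W_hh, A and B as trainable parameters, so gradients flow through C automatically. The names `C` and `modified_rnn_forward` are illustrative only.

```python
import numpy as np


def C(W, A, B):
    """Hypothetical choice of linear map: C(W) = A @ W @ B.

    Linear in W for fixed A and B. A and B are assumed extra
    parameters of shape (hidden_size, hidden_size).
    """
    return A @ W @ B


def modified_rnn_forward(x, h0, W_ih, b_ih, W_hh, b_hh, A, B):
    """h_t = tanh(W_ih x_t + b_ih + C(W_hh) h_{t-1} + b_hh)."""
    # C(W_hh) does not depend on t, so compute it once per sequence
    W_eff = C(W_hh, A, B)
    h = h0
    hs = []
    for x_t in x:  # iterate over time steps
        h = np.tanh(W_ih @ x_t + b_ih + W_eff @ h + b_hh)
        hs.append(h)
    return np.stack(hs)


rng = np.random.default_rng(1)
hidden_size, input_size, T = 4, 3, 6
W1 = rng.standard_normal((hidden_size, hidden_size))
W2 = rng.standard_normal((hidden_size, hidden_size))
A = rng.standard_normal((hidden_size, hidden_size))
B = rng.standard_normal((hidden_size, hidden_size))

# Linearity check: C(2*W1 + 3*W2) == 2*C(W1) + 3*C(W2)
print(np.allclose(C(2 * W1 + 3 * W2, A, B), 2 * C(W1, A, B) + 3 * C(W2, A, B)))  # True

# Run the modified recurrence with W1 playing the role of W_hh
x = rng.standard_normal((T, input_size))
h0 = rng.standard_normal(hidden_size)
W_ih = rng.standard_normal((hidden_size, input_size))
b_ih = np.zeros(hidden_size)
b_hh = np.zeros(hidden_size)
hs = modified_rnn_forward(x, h0, W_ih, b_ih, W1, b_hh, A, B)
print(hs.shape)  # (6, 4)
```

With A and B set to the identity, C becomes the identity map and the update reduces to the standard vanilla RNN.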
How would you recommend implementing this change?
Thanks