Hey,

I want to implement a modification to a vanilla RNN.

The standard update rule of the network is:

```
h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh)
```

I want to make the following change:

```
h_t = tanh(W_ih x_t + b_ih + C(W_hh) h_{t-1} + b_hh)
```

where C is a linear function, so C(W_hh) is a linear transformation of the recurrent weight matrix W_hh.
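
To make the intended update concrete, here is a minimal, framework-free sketch of a single step in plain Python. The helper names (`matvec`, `masked`, `rnn_step`) and the choice of C as an elementwise product with a fixed lower-triangular mask are only assumptions for illustration; the C I actually have in mind could be any linear map on W_hh.

```python
import math

def matvec(M, v):
    """Multiply matrix M (a list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def masked(W, mask):
    """Hypothetical example of C: elementwise product with a fixed mask.

    This is linear in W, but it is only a stand-in for the real C.
    """
    return [[w * m for w, m in zip(w_row, m_row)]
            for w_row, m_row in zip(W, mask)]

def rnn_step(x_t, h_prev, W_ih, b_ih, W_hh, b_hh, C):
    """One step: h_t = tanh(W_ih x_t + b_ih + C(W_hh) h_{t-1} + b_hh)."""
    # W_hh stays the stored weight; C(W_hh) is recomputed from it on each call.
    W_eff = C(W_hh)
    inp = matvec(W_ih, x_t)
    rec = matvec(W_eff, h_prev)
    return [math.tanh(i + bi + r + bh)
            for i, bi, r, bh in zip(inp, b_ih, rec, b_hh)]

# Toy 2-dimensional example (all values are made up).
W_ih = [[1.0, 0.0], [0.0, 1.0]]
W_hh = [[1.0, 2.0], [3.0, 4.0]]
b_ih = [0.0, 0.0]
b_hh = [0.0, 0.0]
lower_tri = [[1.0, 0.0], [1.0, 1.0]]  # assumed mask defining C

x_t = [0.5, -0.5]
h_prev = [1.0, 2.0]

# Modified update: the recurrent term uses C(W_hh) instead of W_hh.
h_t = rnn_step(x_t, h_prev, W_ih, b_ih, W_hh, b_hh,
               C=lambda W: masked(W, lower_tri))

# Passing the identity as C recovers the standard update.
h_t_vanilla = rnn_step(x_t, h_prev, W_ih, b_ih, W_hh, b_hh, C=lambda W: W)
```

In a framework with automatic differentiation, I assume the equivalent would be to keep W_hh as the trainable parameter and apply C inside the forward pass, so that gradients flow through C back to W_hh.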

How would you recommend implementing this change?

Thanks