Implement a linear transformation on one of the model weights

Hey,
I want to implement a modification to a vanilla RNN.
The standard recurrence of the network is

h_t = tanh(W_ih x_t + b_ih + W_hh h_(t-1) + b_hh)
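For reference, here is a minimal sketch, assuming PyTorch, that writes this standard update out by hand and checks it against the built-in nn.RNNCell. That cell's parameters weight_ih, weight_hh, bias_ih and bias_hh play the roles of W_ih, W_hh, b_ih and b_hh. PyTorch stores a batch of inputs as rows, so the products appear as x @ W.T rather than W x.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch = 3, 4, 2

cell = nn.RNNCell(input_size, hidden_size)  # tanh nonlinearity by default
x_t = torch.randn(batch, input_size)
h_prev = torch.randn(batch, hidden_size)

# h_t = tanh(W_ih x_t + b_ih + W_hh h_(t-1) + b_hh), batched with inputs as rows
h_manual = torch.tanh(
    x_t @ cell.weight_ih.T + cell.bias_ih
    + h_prev @ cell.weight_hh.T + cell.bias_hh
)
h_builtin = cell(x_t, h_prev)

print(torch.allclose(h_manual, h_builtin, atol=1e-6))  # True
```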

I want to change it to the following:

h_t = tanh(W_ih x_t + b_ih + C(W_hh) h_(t-1) + b_hh)

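where C is a linear function, so C(W_hh) is a linear transformation of the weight matrix W_hh itself. The result must still be a hidden_size x hidden_size matrix so that it can multiply h_(t-1).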
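One way to do this is to write the recurrence yourself and apply C to W_hh inside forward(), so the optimizer keeps updating W_hh and gradients flow through C. The sketch below assumes PyTorch and, purely for illustration, a hypothetical C(W) = A W B with fixed random matrices A and B. Any other linear map of W_hh with the same output shape could be dropped into the C method.

```python
import torch
import torch.nn as nn


class TransformedRNN(nn.Module):
    """Sketch of h_t = tanh(W_ih x_t + b_ih + C(W_hh) h_(t-1) + b_hh).

    Assumption for illustration: C(W) = A @ W @ B with fixed random A and B.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_proj = nn.Linear(input_size, hidden_size)  # W_ih and b_ih
        self.W_hh = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_hh = nn.Parameter(torch.zeros(hidden_size))
        nn.init.orthogonal_(self.W_hh)
        # Buffers: saved with the model and moved by .to(), but not trained
        scale = hidden_size ** -0.5
        self.register_buffer("A", torch.randn(hidden_size, hidden_size) * scale)
        self.register_buffer("B", torch.randn(hidden_size, hidden_size) * scale)

    def C(self, W):
        # Linear in W: C(a*W1 + b*W2) == a*C(W1) + b*C(W2)
        return self.A @ W @ self.B

    def forward(self, x, h0=None):
        # x: (seq_len, batch, input_size)
        seq_len, batch, _ = x.shape
        h = x.new_zeros(batch, self.hidden_size) if h0 is None else h0
        W_rec = self.C(self.W_hh)    # transform once per forward pass
        x_proj = self.input_proj(x)  # W_ih x_t + b_ih for every t at once
        outputs = []
        for t in range(seq_len):
            h = torch.tanh(x_proj[t] + h @ W_rec.T + self.b_hh)
            outputs.append(h)
        return torch.stack(outputs), h


torch.manual_seed(0)
model = TransformedRNN(input_size=3, hidden_size=4)
x = torch.randn(5, 2, 3)  # (seq_len, batch, input_size)
out, h_n = model(x)
out.sum().backward()
print(out.shape, model.W_hh.grad is not None)  # torch.Size([5, 2, 4]) True
```

Computing C(W_hh) once before the time loop avoids repeating the matrix products at every step. If C should itself be learned, A and B could be registered as nn.Parameter instead of buffers.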

How would you recommend implementing this change?
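An alternative that keeps PyTorch's own cell is torch.nn.utils.parametrize.register_parametrization, which makes every access to a parameter return a function of an underlying raw tensor. The sketch below assumes a PyTorch version that has this module (1.9 or later) and reuses the same hypothetical C(W) = A W B as an illustrative choice. It is applied to nn.RNNCell stepped in a Python loop; whether the fused multi-layer nn.RNN behaves the same way under parametrization is not assumed here.

```python
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class SandwichMap(nn.Module):
    """Hypothetical linear map C(W) = A @ W @ B with fixed A and B."""

    def __init__(self, hidden_size):
        super().__init__()
        scale = hidden_size ** -0.5
        self.register_buffer("A", torch.randn(hidden_size, hidden_size) * scale)
        self.register_buffer("B", torch.randn(hidden_size, hidden_size) * scale)

    def forward(self, W):
        return self.A @ W @ self.B


torch.manual_seed(0)
input_size, hidden_size, batch, seq_len = 3, 4, 2, 5

cell = nn.RNNCell(input_size, hidden_size)
C_map = SandwichMap(hidden_size)
# From now on cell.weight_hh evaluates to C(original W_hh); the trainable
# tensor lives in cell.parametrizations.weight_hh.original
parametrize.register_parametrization(cell, "weight_hh", C_map)

x = torch.randn(seq_len, batch, input_size)
h = torch.zeros(batch, hidden_size)
with parametrize.cached():  # compute C(W_hh) once instead of at every step
    for t in range(seq_len):
        h = cell(x[t], h)   # the cell now uses C(W_hh) in place of W_hh

h.sum().backward()
W_raw = cell.parametrizations.weight_hh.original
print(W_raw.grad is not None)  # True: gradients reach the raw W_hh through C
```

This approach leaves the rest of the training code unchanged, since cell.parameters() still includes the raw W_hh.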
Thanks