Could someone show me how to find where torch._C._VariableFunctions.rnn_tanh is implemented?
(Any general hints or advice are appreciated.)
My motivation: I implemented my own RNN (LSTM) using nn.Linear() layers as components of an nn.Module (I know I can call PyTorch's LSTM directly, but I wanted to learn), and found that my LSTM is very slow on GPU, most likely because I don't compute all 4 gates in parallel. So I wanted to learn how PyTorch's RNN (LSTM) implementation does it.
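For context, the kind of fusion I suspect I'm missing could be sketched like this. The sizes and the single concatenated nn.Linear are my own illustration of the idea (one big matmul for all four gate pre-activations instead of four small ones), not PyTorch's internal code:

```python
import torch
import torch.nn as nn

# Hypothetical toy sizes, just for illustration.
input_size, hidden_size, batch = 8, 16, 4

# Instead of one nn.Linear per gate (4+ small matmuls per step),
# a single weight of shape (4 * hidden_size, input_size + hidden_size)
# computes all four gate pre-activations in one matmul.
fused = nn.Linear(input_size + hidden_size, 4 * hidden_size)

x = torch.randn(batch, input_size)
h = torch.randn(batch, hidden_size)
c = torch.randn(batch, hidden_size)

gates = fused(torch.cat([x, h], dim=1))   # (batch, 4 * hidden_size)
i, f, g, o = gates.chunk(4, dim=1)        # split into the 4 gates

# Standard LSTM cell update using the fused pre-activations.
c_next = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
h_next = torch.sigmoid(o) * torch.tanh(c_next)
```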
I started from RNNBase::forward(), where it calls _rnn_impls['rnn_tanh'], which is torch._C._VariableFunctions.rnn_tanh.
For torch._C._VariableFunctions, I found where it is added as a C++ extension here.
I tried to see if any of its tp_methods are related to RNNs, but couldn't get any further.
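In case it helps anyone reproduce my search, this is roughly how I've been poking at the binding object from the Python side (this assumes torch._C._VariableFunctions is present in your build, as it is in mine):

```python
import torch

# List everything on the binding object whose name mentions "rnn".
rnn_names = [name for name in dir(torch._C._VariableFunctions) if 'rnn' in name]
print(rnn_names)
```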
(Before this, by following RNNBase::init(), I saw that the weights for the 4 gates are placed consecutively in memory here.)
I did find two related, very helpful threads (link1, link2), but I still couldn't figure it out.