Where to find the implementation of `torch._C._VariableFunctions.rnn_tanh`?

Could someone show me how to find where torch._C._VariableFunctions.rnn_tanh is implemented?
(Any general hint or advice is appreciated)

My motivation is that I implemented my own RNN (LSTM) using nn.Linear() modules as components of an nn.Module (I know I can call PyTorch's LSTM, but I wanted to learn), and found that my LSTM is very slow on GPU, most likely because I don't run the operations for all 4 gates in parallel. So I wanted to learn how PyTorch's RNN (LSTM) library does it.
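For reference, a minimal sketch of the fused-gate idea (the class name and structure here are my own, not PyTorch's): instead of 8 separate nn.Linear calls per step, use one `input_size -> 4*hidden_size` layer and one `hidden_size -> 4*hidden_size` layer, then split the result:

```python
import torch
import torch.nn as nn

class FusedGateLSTMCell(nn.Module):
    """Toy LSTM cell that computes all 4 gates with a single matmul
    for the input and a single matmul for the hidden state."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Each linear layer produces the pre-activations of all 4 gates at once.
        self.ih = nn.Linear(input_size, 4 * hidden_size)
        self.hh = nn.Linear(hidden_size, 4 * hidden_size)

    def forward(self, x, state):
        h, c = state
        gates = self.ih(x) + self.hh(h)        # (batch, 4*hidden)
        i, f, g, o = gates.chunk(4, dim=1)     # split into the 4 gates
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c = f * c + i * g
        h = o * torch.tanh(c)
        return h, c

cell = FusedGateLSTMCell(10, 20)
x = torch.randn(3, 10)
h, c = cell(x, (torch.zeros(3, 20), torch.zeros(3, 20)))
print(h.shape)  # torch.Size([3, 20])
```

This alone usually helps on GPU because it turns 8 small matmuls into 2 larger ones; the built-in LSTM additionally fuses across time steps (cuDNN).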

I started from RNNBase::forward(), where it calls _rnn_impls['rnn_tanh'], which is torch._C._VariableFunctions.rnn_tanh.
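As a sanity check that this really is the entry point, the same function is reachable from Python via torch._VF (which mirrors torch._C._VariableFunctions). Assuming the flat weight list is ordered [w_ih, w_hh, b_ih, b_hh] per layer, calling it directly should reproduce nn.RNN's output:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=4, nonlinearity='tanh')
x = torch.randn(5, 2, 3)   # (seq, batch, input)
h0 = torch.zeros(1, 2, 4)

out_mod, h_mod = rnn(x, h0)

# Call the ATen entry point directly; arguments are
# (input, hx, params, has_biases, num_layers, dropout, train,
#  bidirectional, batch_first).
weights = [rnn.weight_ih_l0, rnn.weight_hh_l0, rnn.bias_ih_l0, rnn.bias_hh_l0]
out_vf, h_vf = torch._VF.rnn_tanh(x, h0, weights, True, 1, 0.0, False, False, False)

print(torch.allclose(out_mod, out_vf))
```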

For torch._C._VariableFunctions, I found where it is added as a C++ extension here.

I tried to see if any of its tp_methods were related to the RNN, but could not get any further.

(Before this, by following RNNBase::__init__(), I saw that the weights of the 4 gates are placed consecutively in memory here.)
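You can see that layout directly from Python (assuming the i, f, g, o gate ordering documented for nn.LSTM):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)

# weight_ih_l0 stacks the input-to-hidden weights of all four gates
# into a single (4*hidden_size, input_size) tensor.
w_ih = lstm.weight_ih_l0
print(w_ih.shape)   # torch.Size([80, 10])

# Per-gate views are contiguous row slices of that one tensor.
w_ii, w_if, w_ig, w_io = w_ih.chunk(4, dim=0)
print(w_ii.shape)   # torch.Size([20, 10])
```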

I did find two related, very helpful threads: link1, link2.
But I still could not figure it out.


Bump because I'm looking for a similar thing – I can't find the definition of torch._C._VariableFunctions.gru anywhere in the source code.

All I know is that nn.GRU implements a different forward computation from the one defined in the gru_cell functions, because with exactly the same parameters they produce substantially different activations.

I found the same links and C++ files as the OP, but no luck either. All I wanted was to access my GRU update and reset gate activations… :disappointed:
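For anyone reading later: one way to get those gate activations is to recompute them from the module's parameters. This sketch assumes the standard PyTorch GRU equations and the (reset, update, new) gate ordering within weight_ih_l0 / weight_hh_l0:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=4, hidden_size=5)

x = torch.randn(1, 1, 4)   # (seq, batch, input) — single step for clarity
h0 = torch.zeros(1, 1, 5)

# Gates are stacked in (reset, update, new) order.
w_ir, w_iz, w_in = gru.weight_ih_l0.chunk(3, dim=0)
w_hr, w_hz, w_hn = gru.weight_hh_l0.chunk(3, dim=0)
b_ir, b_iz, b_in = gru.bias_ih_l0.chunk(3)
b_hr, b_hz, b_hn = gru.bias_hh_l0.chunk(3)

xt, ht = x[0], h0[0]
r = torch.sigmoid(xt @ w_ir.T + b_ir + ht @ w_hr.T + b_hr)    # reset gate
z = torch.sigmoid(xt @ w_iz.T + b_iz + ht @ w_hz.T + b_hz)    # update gate
n = torch.tanh(xt @ w_in.T + b_in + r * (ht @ w_hn.T + b_hn)) # candidate
h1 = (1 - z) * n + z * ht

out, hn = gru(x, h0)
print(torch.allclose(h1, hn[0], atol=1e-6))  # should match the built-in cell
```

If this does not match for you, the discrepancy is most likely the gate ordering or a cuDNN-vs-native difference, so treat the ordering above as an assumption to verify against the nn.GRU docs.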

Not 100% sure, but it looks to me like rnn_tanh is defined by macro expansion in aten/src/ATen/native/RNN.cpp.

The macro is defined here. It essentially dispatches to different implementations based on availability. For the cuDNN implementation, it should be in aten/src/ATen/native/cudnn/RNN.cpp.