I have a recurrent model that computes its output from the hidden state h_t of an LSTM and the state m_t of a memory module. In the next time step, that output should be fed back as input after passing through another submodule. To be clear,

y_t = f(x_t, h_t, m_t)

h_t = g(x_t, y_(t-1), h_(t-1))

m_t = p(x_t, h_t, m_(t-1))

I know that the cuDNN implementation of LSTM with packed sequences is efficient. However, it does not accept the output from the previous step as input. I have thought about teacher forcing, where the ground truth is fed in instead of the previous output, but then I would also have to supply the content of m_t, which is an internal state rather than something I have labels for. The only solution I came up with is using LSTMCell and implementing the other modules with two for loops (one over sequences, one over time steps within each sequence).
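For concreteness, here is a minimal sketch of the LSTMCell version I have in mind, with the batch dimension processed in parallel so only the time loop remains (the submodules `p_mem` and `f_out` are placeholder linear layers standing in for my actual p and f):

```python
import torch
import torch.nn as nn

class RecurrentWithMemory(nn.Module):
    def __init__(self, in_dim, hid_dim, mem_dim, out_dim):
        super().__init__()
        # h_t = g(x_t, y_(t-1), h_(t-1)): feed [x_t; y_(t-1)] to the cell
        self.cell = nn.LSTMCell(in_dim + out_dim, hid_dim)
        # m_t = p(x_t, h_t, m_(t-1)) -- placeholder memory update
        self.p_mem = nn.Linear(in_dim + hid_dim + mem_dim, mem_dim)
        # y_t = f(x_t, h_t, m_t) -- placeholder output head
        self.f_out = nn.Linear(in_dim + hid_dim + mem_dim, out_dim)
        self.hid_dim, self.mem_dim, self.out_dim = hid_dim, mem_dim, out_dim

    def forward(self, x):                      # x: (T, B, in_dim)
        T, B, _ = x.shape
        h = x.new_zeros(B, self.hid_dim)
        c = x.new_zeros(B, self.hid_dim)
        m = x.new_zeros(B, self.mem_dim)
        y = x.new_zeros(B, self.out_dim)       # y_0 placeholder
        ys = []
        for t in range(T):                     # single loop over time,
            xt = x[t]                          # whole batch at once
            h, c = self.cell(torch.cat([xt, y], dim=-1), (h, c))
            m = torch.tanh(self.p_mem(torch.cat([xt, h, m], dim=-1)))
            y = self.f_out(torch.cat([xt, h, m], dim=-1))
            ys.append(y)
        return torch.stack(ys)                 # (T, B, out_dim)
```

This at least removes the inner per-sequence loop, but it still launches several kernels per time step, which is exactly the bottleneck I am worried about.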

Is there a more efficient way of implementing such a network? I want to accelerate it on a GPU, but without creating a kernel-launch bottleneck.