In Kaldi's TensorFlow RNNLM script there are two exported functions, get_initial_state(self) and single_step(self, context, word_id) (link: https://github.com/kaldi-asr/kaldi/blob/2c7e78f0757120a965e671f2843c7c07ac7141d4/egs/wsj/s5/steps/tfrnnlm/lstm.py#L115-L139 ):

```
@tf.function
def get_initial_state(self):
    """Exported function which emits zeroed RNN context vector."""
    # This seems a bug in TensorFlow, but passing tf.int32 makes the state tensor also int32.
    fake_input = tf.constant(0, dtype=tf.float32, shape=[1, 1])
    initial_state = tf.stack(self.rnn.get_initial_state(fake_input))
    return {"initial_state": initial_state}

@tf.function
def single_step(self, context, word_id):
    """Exported function which perform one step of the RNN model."""
    rnn = tf.keras.layers.RNN(self.cells, return_state=True)
    context = tf.unstack(context)
    context = [tf.unstack(c) for c in context]

    inputs = self.embedding(word_id)
    rnn_out_and_states = rnn(inputs, initial_state=context)
    rnn_out = rnn_out_and_states[0]
    rnn_states = tf.stack(rnn_out_and_states[1:])
    logits = self.fc(rnn_out)
    output = self.get_score(logits)
    log_prob = output[0, word_id[0, 0]]
    return {"log_prob": log_prob, "rnn_states": rnn_states, "rnn_out": rnn_out}
```

These functions are used in Kaldi language model rescoring. They are exported in the same file (lines 230-236):

    # Export
    print("Saving model to %s." % FLAGS.save_path)
    spec = [tf.TensorSpec(shape=[config.num_layers, 2, 1, config.hidden_size], dtype=data_type(), name="context"),
            tf.TensorSpec(shape=[1, 1], dtype=tf.int32, name="word_id")]
    cfunc = model.single_step.get_concrete_function(*spec)
    cfunc2 = model.get_initial_state.get_concrete_function()
    tf.saved_model.save(model, FLAGS.save_path, signatures={"single_step": cfunc, "get_initial_state": cfunc2})

I tried to do the equivalent for the standard PyTorch word language model example (link: https://github.com/pytorch/examples/blob/master/word_language_model/main.py):

    traced_script_module = torch.jit.trace(model, (train_data, hidden))
    traced_script_module.save('newmodel.pt')
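For comparison, here is a minimal sketch of one way two separate entry points could be exposed from PyTorch. It assumes scripting with torch.jit.script and torch.jit.export instead of tracing, since torch.jit.trace records only the forward pass. The class StepLM, its sizes, and the [2, num_layers, 1, hidden_size] context layout are hypothetical and are not taken from the PyTorch example or from Kaldi.

```python
from typing import Tuple

import torch
import torch.nn as nn


class StepLM(nn.Module):
    """Hypothetical LSTM LM with TF-style get_initial_state / single_step."""

    def __init__(self, vocab_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(
        self, context: torch.Tensor, word_id: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.single_step(context, word_id)

    @torch.jit.export
    def get_initial_state(self) -> torch.Tensor:
        # Zeroed context; assumed layout is [2 (h, c), num_layers, batch=1, hidden].
        return torch.zeros([2, self.num_layers, 1, self.hidden_size])

    @torch.jit.export
    def single_step(
        self, context: torch.Tensor, word_id: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # word_id is an int64 tensor of shape [1, 1] (seq_len=1, batch=1).
        inputs = self.embedding(word_id)
        rnn_out, state = self.rnn(inputs, (context[0], context[1]))
        h, c = state
        logits = self.fc(rnn_out[-1])
        log_probs = torch.log_softmax(logits, dim=-1)
        # Log-probability the model assigns to the word that was fed in.
        log_prob = log_probs[0][word_id[0, 0]]
        rnn_states = torch.stack([h, c])
        return log_prob, rnn_states, rnn_out


model = StepLM(vocab_size=10, hidden_size=4, num_layers=2)
scripted = torch.jit.script(model)  # compiles forward and the exported methods

context = scripted.get_initial_state()
log_prob, rnn_states, rnn_out = scripted.single_step(context, torch.tensor([[3]]))
```

Calling scripted.save('newmodel.pt') on such a module should keep both exported methods available after torch.jit.load. Whether this layout matches what the Kaldi rescoring code expects is a separate question.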

Does tracing and saving the model like this do the same thing as the two TensorFlow functions?

Can I get the initial state, the RNN states, and the RNN output from the saved model?

Thank you in advance; this is very important to me.