In TensorFlow, there are two functions, get_initial_state(self) and single_step(self, context, word_id),
defined as follows (Link: https://github.com/kaldi-asr/kaldi/blob/2c7e78f0757120a965e671f2843c7c07ac7141d4/egs/wsj/s5/steps/tfrnnlm/lstm.py#L115-L139 ):
@tf.function
def get_initial_state(self):
    """Exported function which emits zeroed RNN context vector."""
    # This seems a bug in TensorFlow, but passing tf.int32 makes the state tensor also int32.
    fake_input = tf.constant(0, dtype=tf.float32, shape=[1, 1])
    initial_state = tf.stack(self.rnn.get_initial_state(fake_input))
    return {"initial_state": initial_state}

@tf.function
def single_step(self, context, word_id):
    """Exported function which perform one step of the RNN model."""
    rnn = tf.keras.layers.RNN(self.cells, return_state=True)
    context = tf.unstack(context)
    context = [tf.unstack(c) for c in context]
    inputs = self.embedding(word_id)
    rnn_out_and_states = rnn(inputs, initial_state=context)
    rnn_out = rnn_out_and_states[0]
    rnn_states = tf.stack(rnn_out_and_states[1:])
    logits = self.fc(rnn_out)
    output = self.get_score(logits)
    log_prob = output[0, word_id[0, 0]]
    return {"log_prob": log_prob, "rnn_states": rnn_states, "rnn_out": rnn_out}
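The context tensor that single_step receives packs every layer's LSTM state into one array of shape [num_layers, 2, 1, hidden_size]. The two tf.unstack calls split it into a per-layer [h, c] list for the Keras RNN layer, and tf.stack packs the new states back the same way. Below is a NumPy sketch of just that packing (no LSTM math). The sizes are made-up examples, and the h-before-c order is an assumption based on the Keras LSTMCell state order.

```python
import numpy as np

NUM_LAYERS, HIDDEN = 2, 4  # example sizes, not taken from the Kaldi config


def initial_context(num_layers: int, hidden: int) -> np.ndarray:
    # Zeroed context, analogous to what get_initial_state exports.
    return np.zeros((num_layers, 2, 1, hidden), dtype=np.float32)


def unpack(context: np.ndarray) -> list:
    # Same effect as tf.unstack(context) followed by tf.unstack(c) per layer:
    # one [h, c] pair per layer, each of shape [1, hidden].
    return [[layer[0], layer[1]] for layer in context]


def pack(states: list) -> np.ndarray:
    # Same effect as tf.stack on the per-layer [h, c] states returned by the RNN.
    return np.stack([np.stack(pair) for pair in states])


ctx = initial_context(NUM_LAYERS, HIDDEN)
per_layer = unpack(ctx)
roundtrip = pack(per_layer)
```

Packing everything into one tensor keeps the exported signature to a single fixed-shape "context" input, regardless of how many layers the model has.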
These functions are exported for use in Kaldi language-model rescoring (lines 230-236 of the same file):
# Export
print("Saving model to %s." % FLAGS.save_path)
spec = [tf.TensorSpec(shape=[config.num_layers, 2, 1, config.hidden_size], dtype=data_type(), name="context"),
        tf.TensorSpec(shape=[1, 1], dtype=tf.int32, name="word_id")]
cfunc = model.single_step.get_concrete_function(*spec)
cfunc2 = model.get_initial_state.get_concrete_function()
tf.saved_model.save(model, FLAGS.save_path, signatures={"single_step": cfunc, "get_initial_state": cfunc2})
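With these two signatures, a caller only needs to fetch the zeroed context once and then call single_step repeatedly, feeding each returned rnn_states back in as the next context. Here is a minimal sketch of that state-threading loop. run_steps and the stub functions are hypothetical names for illustration, not Kaldi's rescoring code. With the real SavedModel, the two callables would presumably be loaded.signatures["get_initial_state"] and loaded.signatures["single_step"] (after loaded = tf.saved_model.load(path)), with each word passed as something like tf.constant([[w]], dtype=tf.int32).

```python
from typing import Any, Callable, Dict, Iterable, List, Tuple


def run_steps(get_initial_state: Callable[[], Dict[str, Any]],
              single_step: Callable[..., Dict[str, Any]],
              word_inputs: Iterable[Any]) -> Tuple[List[Dict[str, Any]], Any]:
    """Thread the RNN state through successive single_step calls (hypothetical helper)."""
    context = get_initial_state()["initial_state"]
    results = []
    for word_id in word_inputs:
        # Keyword names follow the TensorSpec names used in the export code.
        result = single_step(context=context, word_id=word_id)
        results.append(result)
        context = result["rnn_states"]  # the new state becomes the next context
    return results, context


# Toy stand-ins so the loop can run without TensorFlow:
# the "state" is just a running sum of the word ids seen so far.
def _stub_initial_state() -> Dict[str, Any]:
    return {"initial_state": 0}


def _stub_single_step(context: int, word_id: List[List[int]]) -> Dict[str, Any]:
    new_state = context + word_id[0][0]
    return {"log_prob": -1.0, "rnn_states": new_state, "rnn_out": new_state}


results, final_state = run_steps(_stub_initial_state, _stub_single_step, [[[2]], [[3]], [[5]]])
```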
Here is what I did for the PyTorch LM (the standard word_language_model example):
(Link: https://github.com/pytorch/examples/blob/master/word_language_model/main.py)
traced_script_module = torch.jit.trace(model, (train_data, hidden))
traced_script_module.save('newmodel.pt')
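As far as I understand, torch.jit.trace records only the forward call for the example inputs given, so newmodel.pt exposes a single entry point (forward) rather than two named functions like the TF SavedModel. One way to get two separate entry points is to script the module and mark the extra methods with @torch.jit.export. Below is a minimal sketch of that approach. TinyLSTMLM is a hypothetical, simplified stand-in for the example's RNNModel (LSTM only, no dropout), and its [num_layers, 2, 1, hidden] context layout simply mirrors the TF export. It is not the example's own code.

```python
import os
import tempfile
from typing import Dict

import torch
import torch.nn as nn


class TinyLSTMLM(nn.Module):
    """Hypothetical, simplified stand-in for the example's RNNModel (LSTM only)."""

    def __init__(self, ntoken: int, ninp: int, nhid: int, nlayers: int):
        super().__init__()
        self.nlayers = nlayers
        self.nhid = nhid
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers)  # seq-first: [seq_len, batch, feature]
        self.decoder = nn.Linear(nhid, ntoken)

    @torch.jit.export
    def get_initial_state(self) -> torch.Tensor:
        # Zeroed context laid out as [num_layers, 2 (h, c), batch=1, hidden].
        return torch.zeros([self.nlayers, 2, 1, self.nhid])

    @torch.jit.export
    def single_step(self, context: torch.Tensor, word_id: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Split the packed context into the (h0, c0) tuple that nn.LSTM expects.
        h0 = context[:, 0].contiguous()  # [num_layers, 1, hidden]
        c0 = context[:, 1].contiguous()
        emb = self.encoder(word_id)  # word_id [1, 1] -> [1, 1, ninp]
        out, state = self.rnn(emb, (h0, c0))
        rnn_out = out[-1]  # last (only) time step: [1, hidden]
        log_probs = torch.log_softmax(self.decoder(rnn_out), dim=-1)  # [1, ntoken]
        # Same indexing as the TF code: the entry at word_id[0, 0].
        log_prob = torch.gather(log_probs, 1, word_id)[0, 0]
        # Re-pack (h, c) into [num_layers, 2, 1, hidden].
        rnn_states = torch.stack([state[0], state[1]], dim=1)
        return {"log_prob": log_prob, "rnn_states": rnn_states, "rnn_out": rnn_out}

    def forward(self, context: torch.Tensor, word_id: torch.Tensor) -> Dict[str, torch.Tensor]:
        return self.single_step(context, word_id)


model = TinyLSTMLM(ntoken=10, ninp=8, nhid=16, nlayers=2).eval()
scripted = torch.jit.script(model)  # scripting keeps the exported methods
path = os.path.join(tempfile.mkdtemp(), "lm_scripted.pt")
scripted.save(path)

loaded = torch.jit.load(path)
context = loaded.get_initial_state()
with torch.no_grad():
    result = loaded.single_step(context, torch.tensor([[3]]))  # int64 word id, shape [1, 1]
```

After loading, both get_initial_state and single_step can be called by name, much like the two TF signatures, and single_step returns log_prob, rnn_states, and rnn_out.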
Does this do the same thing as the two TensorFlow functions?
Can I get the equivalents of initial_state, rnn_states, and rnn_out from it?
Thank you in advance; this is very important to me.