Rewriting TensorFlow code in PyTorch (Need Review) - Just 2 functions

In the TensorFlow code there are two functions, get_initial_state(self) and single_step(self, context, word_id):

where (Link: )

    def get_initial_state(self):
        """Exported function which emits zeroed RNN context vector."""
        # This seems a bug in TensorFlow, but passing tf.int32 makes the state tensor also int32.
        fake_input = tf.constant(0, dtype=tf.float32, shape=[1, 1])
        initial_state = tf.stack(self.rnn.get_initial_state(fake_input))
        return {"initial_state": initial_state}

    def single_step(self, context, word_id):
        """Exported function which perform one step of the RNN model."""
        rnn = tf.keras.layers.RNN(self.cells, return_state=True)
        context = tf.unstack(context)
        context = [tf.unstack(c) for c in context]

        inputs = self.embedding(word_id)
        rnn_out_and_states = rnn(inputs, initial_state=context)

        rnn_out = rnn_out_and_states[0]
        rnn_states = tf.stack(rnn_out_and_states[1:])

        logits = self.fc(rnn_out)
        output = self.get_score(logits)
        log_prob = output[0, word_id[0, 0]]
        return {"log_prob": log_prob, "rnn_states": rnn_states, "rnn_out": rnn_out}

These functions are used for Kaldi language model rescoring. They are exported in the same file, lines 230-236:


    print("Saving model to %s." % FLAGS.save_path)
    spec = [tf.TensorSpec(shape=[config.num_layers, 2, 1, config.hidden_size], dtype=data_type(), name="context"),
            tf.TensorSpec(shape=[1, 1], dtype=tf.int32, name="word_id")]
    cfunc = model.single_step.get_concrete_function(*spec)
    cfunc2 = model.get_initial_state.get_concrete_function()
    tf.saved_model.save(model, FLAGS.save_path, signatures={"single_step": cfunc, "get_initial_state": cfunc2})

For the PyTorch LM (the standard example), I did the following:

    traced_script_module = torch.jit.trace(model, (train_data, hidden))

Does this do the same as the two TensorFlow functions?
Can I get initial_state, rnn_states, and rnn_out from it?

Thank you in advance (this is very important to me).

Could you explain what the TF methods are doing?
I would recommend storing fixed numpy inputs, loading them in both frameworks (in PyTorch you would have to transform the numpy array to a tensor via torch.from_numpy), and comparing the outputs.
This would make sure that both methods are indeed using the same underlying operations and parameters.
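
A minimal sketch of that comparison on the PyTorch side. The shapes, the vocabulary size, and the `compare` helper are assumptions made up for illustration; in practice the arrays would be written once with np.save and loaded from the same files in both scripts.

```python
import numpy as np
import torch

# Hypothetical fixed inputs (illustrative sizes, not taken from the Kaldi recipe).
rng = np.random.default_rng(0)
word_id = rng.integers(0, 100, size=(1, 1))                      # one token id, int64
context = rng.standard_normal((2, 2, 1, 8)).astype(np.float32)   # [layers, 2, batch, hidden]

# PyTorch side: wrap the numpy arrays without copying.
word_id_t = torch.from_numpy(word_id)
context_t = torch.from_numpy(context)


def compare(reference: np.ndarray, torch_out: torch.Tensor, name: str, atol: float = 1e-5) -> None:
    """Raise AssertionError if the PyTorch output differs from the reference array."""
    np.testing.assert_allclose(reference, torch_out.detach().cpu().numpy(),
                               atol=atol, err_msg=name)
```

The reference array would come from the TensorFlow run (for example, log_prob or rnn_states converted with .numpy() and saved to disk), and `compare` is then called once per output.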

Sir, the first function, get_initial_state(), gives initial_state (h_0),
whereas the second function, single_step(self, context, word_id), gives log_prob (output probability), rnn_states (h_(n-1):h_1) and the output (h_n).

Please correct me if I am wrong. I need these in PyTorch.
So far I have stored:

    traced_script_module = torch.jit.trace(model, (train_data, hidden))

I will be using it in C++.

It is for PyTorch-Kaldi language model rescoring.

The mentioned TF functions seem to be related to some internal utility functions for RNNs, while your PyTorch code traces a model and stores it.
I'm not sure how these approaches are related.

If your PyTorch model also implements the TF methods, then you should probably be fine.

Sir, do you have any suggestion for implementing the TF methods in PyTorch? Sorry to disturb you again, I am chasing the Interspeech deadline.

Moreover, we need to call everything (the model and the methods' results) from C++, because the Kaldi side accepts only C++ code for lattice rescoring.

As I'm not deeply familiar with the TF code, I assume get_initial_state might be equivalent to init_hidden and single_step seems to refer to the forward method.

If you need exactly the same results, I would still recommend running the methods in isolation with a fixed input array and making sure you get the desired outputs.
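
Under that assumption, one possible way to expose both functions from a single exported module is sketched below. The class `RNNLM`, its layer names, and all sizes are hypothetical and only loosely follow the standard word language model example. Note that torch.jit.trace only records the call it is given (here forward), so this sketch uses torch.jit.script together with @torch.jit.export to keep the extra methods in the saved module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RNNLM(nn.Module):
    """Hypothetical LSTM language model; names and sizes are illustrative."""

    def __init__(self, ntoken: int, ninp: int, nhid: int, nlayers: int):
        super().__init__()
        self.nlayers = nlayers
        self.nhid = nhid
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, word_id: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        # word_id: [seq_len, batch] (int64); h, c: [nlayers, batch, nhid]
        emb = self.encoder(word_id)
        out, state = self.rnn(emb, (h, c))
        log_probs = F.log_softmax(self.decoder(out), dim=-1)
        return log_probs, state[0], state[1]

    @torch.jit.export
    def get_initial_state(self, bsz: int):
        # Zeroed hidden and cell state, the counterpart of init_hidden.
        h = torch.zeros(self.nlayers, bsz, self.nhid)
        c = torch.zeros(self.nlayers, bsz, self.nhid)
        return h, c

    @torch.jit.export
    def single_step(self, word_id: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        # One step for one word: word_id has shape [1, 1].
        log_probs, new_h, new_c = self.forward(word_id, h, c)
        # Log-probability that the model assigns to word_id itself.
        log_prob = log_probs[0, 0].gather(0, word_id[0])[0]
        return log_prob, new_h, new_c


model = RNNLM(ntoken=100, ninp=16, nhid=32, nlayers=2).eval()
scripted = torch.jit.script(model)

h, c = scripted.get_initial_state(1)
log_prob, h, c = scripted.single_step(torch.tensor([[3]]), h, c)
```

After scripted.save(...), the exported methods should be reachable from C++ by name through the loaded module (for example via get_method), but the exact return layout Kaldi expects is not something this sketch can confirm, so it would still need to be checked against the TF outputs.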

Sir, yes, get_initial_state is equivalent to init_hidden and single_step is the forward method.

That's why I used:

    traced_script_module = torch.jit.trace(model, (train_data, hidden))

and thought I would use the forward function in C++.

I had to write these functions because they are used in a wrapper to get the initial state.

This is what I have written: Link

    def get_initial_state(self, bsz):
        # bsz -> batch size
        # Equivalent to: h, c = self.init_hidden(bsz)
        h = torch.zeros(self.nlayers, bsz, self.nhid)
        c = torch.zeros(self.nlayers, bsz, self.nhid)
        # An LSTM needs both the hidden and the cell state
        return h, c
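
The TF export above passes the state as a single context tensor of shape [num_layers, 2, 1, hidden_size]. If the C++ wrapper expects the same layout, the two PyTorch state tensors could be packed and unpacked roughly as follows. The helper names are made up, and the order of h and c along the second axis is an assumption that should be checked against the TF side.

```python
import torch


def pack_state(h: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
    # Two tensors of shape [nlayers, bsz, nhid] -> one of shape [nlayers, 2, bsz, nhid].
    return torch.stack((h, c), dim=1)


def unpack_state(context: torch.Tensor):
    # Inverse of pack_state; contiguous copies so nn.LSTM accepts them.
    h, c = context.unbind(dim=1)
    return h.contiguous(), c.contiguous()


nlayers, bsz, nhid = 2, 1, 8  # illustrative sizes
h0 = torch.zeros(nlayers, bsz, nhid)
c0 = torch.ones(nlayers, bsz, nhid)

context = pack_state(h0, c0)
h, c = unpack_state(context)
```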