Batched inference with a pre-trained language model

Hello everyone,
I have trained a Language Model using the word language model codebase and now I’d like to use this trained model only for inference. So, I looked into generate.py which does this job. But, the problem is that it accepts a single as the starting index and then generates words based on the number of words we ask it to generate.

My use case is a little different. I will get a bunch of indices which come from another language model, and I want to use these indices to do batched inference on my pre-trained language model.

Question: Is there any way to do batched inference and get the hidden states from the pre-trained language model?
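Conceptually, something like the sketch below is what I'm after; `index_sequences` and `batch_size` are my own placeholder names, and I'm assuming the stock RNNModel interface, whose forward pass expects input of shape (seq_len, batch_size):

```python
import torch

# Assumes all sequences have the same length; otherwise padding would be needed.
batch = torch.stack(index_sequences, dim=1)  # (seq_len, batch_size)
hidden = model.init_hidden(batch_size)
with torch.no_grad():
    output, hidden = model(batch, hidden)    # one forward pass for all sequences
# `hidden` now holds the final hidden states for every sequence in the batch
```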

My naive solution is to loop over the indices and pass each one to the pre-trained language model to get the corresponding hidden states (see the sketch below). However, this is super slow because I have to run such an inference step thousands of times, again and again.
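Roughly, this is what I'm doing now (`index_sequences` is again a placeholder for the indices coming from the other model):

```python
import torch

hidden_states = []
with torch.no_grad():
    for seq in index_sequences:                  # thousands of sequences
        hidden = model.init_hidden(1)            # batch size 1 every time
        _, hidden = model(seq.unsqueeze(1), hidden)  # input shape (seq_len, 1)
        hidden_states.append(hidden)
```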