Ways to improve PyTorch's performance during inference with a small model and small batches?

Hi. I’m training a simple model (a seq2seq RNN) and using it in inference mode to generate new batches for training from a stream of data via beam search. The beam search code uses a lot of slicing, indexing, and conditional statements, and everything there is processed through PyTorch tensors and functions. Essentially, to generate a new training batch it needs to run one RNN step many times on quite small batches. I’ve already wrapped it in torch.no_grad(), but what else can I do to improve performance?
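To make the setup concrete, here is a minimal sketch of the kind of loop described above: a single-step seq2seq decoder driven one token at a time, as in beam search. The module, names, and sizes are all illustrative, not my actual model. One thing I tried is replacing torch.no_grad() with torch.inference_mode(), which disables some additional per-op bookkeeping (version counters, view tracking) and should matter most in exactly this regime of many tiny ops on small batches:

```python
import torch
import torch.nn as nn

# Hypothetical single-step decoder standing in for one RNN step of the
# seq2seq model; vocab/hidden sizes are placeholders.
class DecoderStep(nn.Module):
    def __init__(self, vocab_size=32, hidden_size=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRUCell(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size * 0 + vocab_size)

    def forward(self, tokens, hidden):
        # tokens: (beam,) token ids; hidden: (beam, hidden_size)
        hidden = self.rnn(self.embed(tokens), hidden)
        return self.out(hidden), hidden

model = DecoderStep().eval()

beam = 4
tokens = torch.zeros(beam, dtype=torch.long)
hidden = torch.zeros(beam, 64)

# inference_mode() is stricter than no_grad(): it also skips autograd
# metadata tracking, shaving a bit of per-op overhead.
with torch.inference_mode():
    for _ in range(10):
        logits, hidden = model(tokens, hidden)
        # Greedy step here for brevity; the real code expands and
        # re-ranks beam hypotheses at this point.
        tokens = logits.argmax(dim=-1)

print(tokens.shape)  # torch.Size([4])
```

This helps a little, but the per-call Python and dispatcher overhead still dominates for me, which is why I'm asking what else can be done.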