Batch-wise beam search decoding

I am using the following code to implement beam search for text generation.

import torch

def beam_search_decoder(data, k):
	# data: (seq_len, vocab_size) tensor of per-step probabilities
	sequences = [[list(), 0.0]]
	# walk over each step in sequence
	for row in data:
		all_candidates = list()
		# expand each current candidate
		for i in range(len(sequences)):
			seq, score = sequences[i]
			# extend the candidate with every vocabulary token
			for j in range(len(row)):
				candidate = [seq + [j], score - torch.log(row[j])]
				all_candidates.append(candidate)
		# order all candidates by score (lowest negative log-prob first)
		ordered = sorted(all_candidates, key=lambda tup: tup[1])
		# select k best
		sequences = ordered[:k]
	return sequences

This performs beam search decoding for batch_size 1. I am working with a vocab_size of ~9900, so this function itself is extremely slow. Is there a faster way to do batch-wise beam search?
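One way to speed this up considerably is to vectorize the inner loops with `torch.topk` and process the whole batch at once. Below is a minimal sketch (the function name `batch_beam_search` and the assumed input shape `(batch, seq_len, vocab)` of precomputed log-probabilities are my own choices, not from the original post). It keeps the k highest log-probability scores per batch element, which is equivalent to keeping the k lowest negative log-probabilities in the loop version:

```python
import torch

def batch_beam_search(log_probs, k):
    """Vectorized beam search over precomputed per-step log-probabilities.

    log_probs: (batch, seq_len, vocab) tensor of log-probabilities.
    Returns (sequences, scores): sequences is (batch, k, seq_len),
    scores is (batch, k), both sorted best-first.
    """
    batch, seq_len, vocab = log_probs.shape
    # Step 0: top-k tokens per batch element seed the beams.
    scores, indices = log_probs[:, 0, :].topk(k, dim=-1)        # (batch, k)
    sequences = indices.unsqueeze(-1)                            # (batch, k, 1)
    for t in range(1, seq_len):
        # Add each beam's running score to every vocab entry: (batch, k, vocab)
        cand = scores.unsqueeze(-1) + log_probs[:, t, :].unsqueeze(1)
        # Flatten beams x vocab and take the global top-k per batch element.
        scores, flat_idx = cand.view(batch, -1).topk(k, dim=-1)
        beam_idx = torch.div(flat_idx, vocab, rounding_mode='floor')  # source beam
        token_idx = flat_idx % vocab                                  # extending token
        # Reorder surviving prefixes and append the new tokens.
        sequences = torch.cat(
            [sequences.gather(1, beam_idx.unsqueeze(-1).expand(-1, -1, t)),
             token_idx.unsqueeze(-1)], dim=-1)
    return sequences, scores
```

This replaces the `O(beams * vocab)` Python loop per step with a few tensor ops, so the ~9900-wide vocab is handled in one `topk` call. Note this assumes all log-probabilities are precomputed; in autoregressive generation you would instead re-run the model on the selected beams at each step.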
