I am using the following code to implement beam search for text generation:
    import torch

    def beam_search_decoder(data, k):
        # each hypothesis is [token_ids, score]; score is the summed negative log-probability
        sequences = [[list(), 0.0]]
        # walk over each step in sequence
        for row in data:
            all_candidates = list()
            # expand each current candidate
            for i in range(len(sequences)):
                seq, score = sequences[i]
                for j in range(len(row)):
                    candidate = [seq + [j], score - torch.log(row[j])]
                    all_candidates.append(candidate)
            # order all candidates by score (lower is better)
            ordered = sorted(all_candidates, key=lambda tup: tup[1])
            # select k best
            sequences = ordered[:k]
        return sequences
This performs beam search decoding for a batch size of 1. I am working with a vocab_size of ~9900, so this function alone is extremely slow. Is there a faster way to do beam search batch-wise?
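One possible direction, as a minimal sketch rather than a definitive implementation: replace the Python loops over hypotheses and vocabulary entries with tensor operations, and use `torch.topk` instead of sorting every candidate. The function name `batched_beam_search` and the assumed input layout `(batch, steps, vocab)` are illustrative choices, not part of the original code. Like the function above, it treats the input as a fixed table of per-step probabilities and uses summed negative log-probabilities as scores.

```python
import torch

def batched_beam_search(probs, k):
    """Sketch of a vectorized beam search over a fixed probability table.

    probs: float tensor of shape (batch, steps, vocab) (assumed layout).
    Returns (seqs, scores): token ids of shape (batch, beams, steps) and
    summed negative log-probabilities of shape (batch, beams), best first,
    where beams = min(k, vocab).
    """
    batch, steps, vocab = probs.shape
    nll = -torch.log(probs)  # (batch, steps, vocab)

    # Start with one empty hypothesis per batch element.
    scores = probs.new_zeros(batch, 1)
    seqs = torch.zeros(batch, 1, steps, dtype=torch.long, device=probs.device)

    for t in range(steps):
        # Score of every (hypothesis, token) extension: (batch, beams, vocab).
        cand = scores.unsqueeze(-1) + nll[:, t].unsqueeze(1)
        flat = cand.reshape(batch, -1)

        # Keep the k lowest-scoring extensions instead of sorting them all.
        n_keep = min(k, flat.shape[1])
        scores, idx = flat.topk(n_keep, dim=1, largest=False)

        # Recover which hypothesis and which token each flat index refers to.
        beam_idx = torch.div(idx, vocab, rounding_mode="floor")
        tok_idx = idx % vocab

        # Copy the surviving parent sequences and write the new token at step t.
        gather_idx = beam_idx.unsqueeze(-1).expand(-1, -1, steps)
        seqs = seqs.gather(1, gather_idx)
        seqs[:, :, t] = tok_idx

    return seqs, scores
```

Usage would look like `seqs, scores = batched_beam_search(probs, k=5)`, with `seqs[b, 0]` being the best sequence for batch element `b`. For a real autoregressive model the step probabilities depend on the chosen prefix, so the model call would have to move inside the loop with the beam dimension folded into the batch; the sketch only mirrors the fixed-table setting of the function above.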