I am using the following code to implement beam search for text generation.
import torch

def beam_search_decoder(data, k):
    # each entry holds a partial sequence and its cumulative score
    sequences = [[list(), 0.0]]
    # walk over each step in the sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate with every token in the vocabulary
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score - torch.log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score (lower negative log-likelihood is better)
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        # select k best
        sequences = ordered[:k]
    return sequences
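For context, I call it on the per-step probability distribution for a single example, roughly like this (toy data here; in my case vocab_size is ~9900):

import torch

seq_len, vocab_size = 10, 9900
# per-step token probabilities for one example
probs = torch.softmax(torch.randn(seq_len, vocab_size), dim=-1)
for tokens, score in beam_search_decoder(probs, k=3):
    print(tokens, float(score))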
This performs beam search decoding for batch_size 1, and with a vocab_size of ~9900 the nested Python loops make it extremely slow. Is there a faster way to do batch-wise beam search?
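To make it concrete, the kind of per-step update I am imagining is sketched below: it works on log-probabilities directly, replaces the Python loops with torch.topk over the flattened beam-times-vocab scores, and keeps the k highest-scoring beams per example (equivalent to the lowest negative log-likelihood in the function above). The function name, shapes, and overall structure are my own guesses, it ignores EOS handling, and I have not benchmarked it:

import torch

def batched_beam_search(log_probs, k):
    # log_probs: [batch, seq_len, vocab_size] per-step log-probabilities
    batch, seq_len, vocab_size = log_probs.shape
    # initialise the beams with the top-k tokens of the first step
    scores, tokens = log_probs[:, 0].topk(k, dim=-1)            # [batch, k]
    beams = tokens.unsqueeze(-1)                                # [batch, k, 1]
    for t in range(1, seq_len):
        # every beam extended by every token: [batch, k, vocab_size]
        cand = scores.unsqueeze(-1) + log_probs[:, t].unsqueeze(1)
        # keep the k best extensions per example
        scores, flat_idx = cand.view(batch, -1).topk(k, dim=-1)
        beam_idx = flat_idx // vocab_size                       # which beam each winner extends
        token_idx = flat_idx % vocab_size                       # which token was appended
        parents = beams.gather(
            1, beam_idx.unsqueeze(-1).expand(-1, -1, beams.size(-1)))
        beams = torch.cat([parents, token_idx.unsqueeze(-1)], dim=-1)
    return beams, scores                                        # [batch, k, seq_len], [batch, k]

Would something along these lines be the right direction, or is there a standard implementation I should be using instead?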