I am using torchtext.data.metrics.bleu_score to calculate the bleu score between hypothesis and reference corpus. Each input is an iterable, so this should work based on the example in the docs. But I get this error. Any suggestions?
import torchtext
def get_bleu_score(reference, hypothesis):
'''
Args:
- reference: target sentences (batch_size x seq_len)
- hypothesis: predicted sentences (batch_size x seq_len)
Return:
- bleu_score: corpos bleu score between reference and hypothesis
'''
# remove special tokens
new_reference = []
for sentence in reference:
new_sentence=[]
for word in sentence:
if word == 'EOS':
break
if word != 'BOS' and word != 'EOS' and word != 'PAD':
new_sentence.append(word)
new_reference.append(new_sentence)
new_reference = [new_reference]
new_hypothesis = []
for sentence in hypothesis:
new_sentence = []
for word in sentence:
if word == 'EOS':
break
if word != 'BOS' and word != 'EOS' and word != 'PAD':
new_sentence.append(word)
new_hypothesis.append(new_sentence)
print('reference: ', new_reference)
print('hypothesis: ', new_hypothesis)
# calculate BLEU score
bleu_score = torchtext.data.metrics.bleu_score(new_reference, new_hypothesis)
return bleu_score
reference = [['BOS','i','want','some','ice','cream','.','EOS','PAD']]
hypothesis = [['BOS','i','want','a','ice','cream','.','EOS','PAD']]
bleu = get_bleu_score(reference, hypothesis)
print('bleu: ', bleu)
reference: [[['i', 'want', 'some', 'ice', 'cream', '.']]]
hypothesis: [['i', 'want', 'a', 'ice', 'cream', '.']]
Traceback (most recent call last):
File "~/test_bleu.py", line 45, in <module>
bleu = get_bleu_score(reference, hypothesis)
File "~/test_bleu.py", line 38, in get_bleu_score
bleu_score = torchtext.data.metrics.bleu_score(new_reference, new_hypothesis)
File "~/myenv2/lib/python3.9/site-packages/torchtext/data/metrics.py", line 78, in bleu_score
candidate_counter = _compute_ngram_counter(candidate, max_n)
File "~/myenv2/lib/python3.9/site-packages/torchtext/data/metrics.py", line 29, in _compute_ngram_counter
ngrams_counter = collections.Counter(tuple(x.split(' '))
File "~/myenv2/lib/python3.9/collections/__init__.py", line 593, in __init__
self.update(iterable, **kwds)
File "~/myenv2/lib/python3.9/collections/__init__.py", line 679, in update
_count_elements(self, iterable)
File "~/myenv2/lib/python3.9/site-packages/torchtext/data/metrics.py", line 29, in <genexpr>
ngrams_counter = collections.Counter(tuple(x.split(' '))
AttributeError: 'list' object has no attribute 'split'