Can you open an issue on the torchtext GitHub with a code snippet? The translation dataset has not been rewritten and there may be some bugs there, but we can try to have some people take a look at it.
```python
import spacy
import torch
from torchtext.data import Field, BucketIterator
from torchtext.datasets import IWSLT

spacy_de = spacy.load('de')
spacy_en = spacy.load('en')

def tokenize_de(text):
    """
    Tokenizes German text from a string into a list of strings (tokens) and reverses it.
    """
    return [tok.text for tok in spacy_de.tokenizer(text)][::-1]

def tokenize_en(text):
    """
    Tokenizes English text from a string into a list of strings (tokens).
    """
    return [tok.text for tok in spacy_en.tokenizer(text)]

# These names are defined elsewhere in the original script; the values below
# are typical assumptions, not taken from the original post.
init_token = '<sos>'
eos_token = '<eos>'
max_seq_len = 100
min_freq = 2
batch_size = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SRC = Field(tokenize=tokenize_de,
            init_token=init_token,
            eos_token=eos_token,
            lower=True,
            batch_first=True)
TRG = Field(tokenize=tokenize_en,
            init_token=init_token,
            eos_token=eos_token,
            lower=True,
            batch_first=True)

train_data, valid_data = IWSLT.splits(exts=('.de', '.en'),
                                      fields=(SRC, TRG),
                                      test=None,
                                      filter_pred=lambda x: len(vars(x)['src']) <= max_seq_len and
                                                            len(vars(x)['trg']) <= max_seq_len)

SRC.build_vocab(train_data, min_freq=min_freq)
TRG.build_vocab(train_data, min_freq=min_freq)

train_iter, valid_iter = BucketIterator.splits(
    (train_data, valid_data),
    batch_size=batch_size,
    device=device)
```
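For reference, the indices 1 and 3 come from how the legacy torchtext `Field.build_vocab` orders its special tokens: `<unk>`, then `pad_token`, then `init_token`, then `eos_token`. A minimal sketch of that ordering (plain Python, not torchtext's actual code, and assuming `'<sos>'`/`'<eos>'` were used for `init_token`/`eos_token`):

```python
# Sketch of the legacy torchtext special-token ordering in the vocab.
unk_token, pad_token = '<unk>', '<pad>'
init_token, eos_token = '<sos>', '<eos>'  # assumed values

specials = [unk_token, pad_token, init_token, eos_token]
stoi = {tok: i for i, tok in enumerate(specials)}

# stoi['<pad>'] is 1 and stoi['<eos>'] is 3, which matches the indices
# mentioned in this thread.
```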
As I said at the beginning, I think the last element of src and trg should be the <eos> token, which is index 3, but in fact the last element of src and trg is 1, the <pad> token.
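This is expected for any sequence shorter than the longest one in its batch: the iterator pads every example to the batch's max length, so `<eos>` still appears, but `<pad>` tokens come after it. A small sketch of that padding behavior using `torch.nn.utils.rnn.pad_sequence` (the token ids 5 and 7 are made up for illustration; 1 and 3 are the `<pad>`/`<eos>` indices from this thread):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_IDX, EOS_IDX = 1, 3

# Two "sentences" of different lengths, each already ending in <eos>.
seqs = [torch.tensor([5, 7, EOS_IDX]),
        torch.tensor([5, EOS_IDX])]

# Pad to the max length in the batch, as BucketIterator does.
batch = pad_sequence(seqs, batch_first=True, padding_value=PAD_IDX)

# batch[0] is [5, 7, 3]: ends in <eos>, no padding needed.
# batch[1] is [5, 3, 1]: <eos> is present, but <pad> comes after it,
# so the *last* element of the row is 1, not 3.
```

So the last column of a padded batch is only `<eos>` for the longest example(s); everything else ends in `<pad>`.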