Hi! I have this code:
import re
import torch
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup,
)
class TranslateDataset(Dataset):
    def __init__(
        self,
        data_path: str,
        tokenizer: T5Tokenizer,
        max_seq_len: int = 20,
        max_target_len: int = 20,
        memory_len: int = 1,
        type: str = "train",
    ):
        self.type = type
        self.base_folder = data_path
        if type == 'train':
            self.data = load_dataset(
                'text',
                data_files={
                    'train_en': self.base_folder + 'en_train',
                    'train_ru': self.base_folder + 'ru_train',
                },
            )
        elif type == 'dev':
            self.data = load_dataset(
                'text',
                data_files={
                    'dev_en': self.base_folder + 'en_dev',
                    'dev_ru': self.base_folder + 'ru_dev',
                },
            )
        elif type == 'test':
            self.data = load_dataset(
                'text',
                data_files={
                    'test_en': self.base_folder + 'en_test',
                    'test_ru': self.base_folder + 'ru_test',
                },
            )
        else:
            raise ValueError(
                "Invalid dataset type: %s\ntype should be one of {'train', 'dev', 'test'}" % type
            )

        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.max_target_len = max_target_len
        self.memory_len = memory_len
        print("data: ", self.data)
    def __len__(self):
        return len(self.data[self.type + '_en'])

    def __getitem__(self, index: int):
        source_row = self.data[self.type + '_en'][index]['text']
        target_row = self.data[self.type + '_ru'][index]['text']

        # process source: each row holds several sub-sentences joined by '_eos'
        source_row = source_row.split('_eos')
        num_seq = len(source_row)
        source_ids = []
        source_masks = []
        for i, sub_sentence in enumerate(source_row):
            source_encoding = self.tokenizer(
                sub_sentence,
                max_length=self.max_seq_len,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            source_ids.append(source_encoding["input_ids"].flatten())
            source_masks.append(source_encoding["attention_mask"].flatten())
        source_ids = torch.cat(source_ids)
        source_masks = torch.cat(source_masks)

        # process target
        targets_ids = []
        targets_masks = []
        target_row = target_row.split('_eos')
        # print("target_row: ", target_row)
        for i, sub_sentence in enumerate(target_row):
            target_encoding = self.tokenizer(
                sub_sentence,
                max_length=self.max_target_len,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            labels = target_encoding['input_ids']
            labels[labels == 0] = -100  # mask pad tokens out of the loss
            targets_ids.append(labels.flatten())
            targets_masks.append(target_encoding["attention_mask"].flatten())
        targets_ids = torch.cat(targets_ids)
        targets_masks = torch.cat(targets_masks)

        # NOTE: only num_seq is returned here; the tensors built above are discarded
        return dict(
            num_seq=4,
        )
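        # What I eventually want __getitem__ to return is every tensor built
        # above, i.e. something like this (my intent, not what the code above
        # currently does):
        #     return dict(
        #         num_seq=num_seq,
        #         source_ids=source_ids,
        #         source_masks=source_masks,
        #         targets_ids=targets_ids,
        #         targets_masks=targets_masks,
        #     )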
tokenizer = T5Tokenizer.from_pretrained('t5-small')
dataset = TranslateDataset(
    data_path=base_folder,  # base_folder is defined earlier in my script
    tokenizer=tokenizer,
    max_seq_len=12,
    max_target_len=12,
    memory_len=1,
)
dataloader = DataLoader(
    dataset,
    batch_size=2,
    num_workers=2,
)
next(iter(dataloader))
and this is the result:
data: DatasetDict({
    train_en: Dataset({
        features: ['text'],
        num_rows: 1500000
    })
    train_ru: Dataset({
        features: ['text'],
        num_rows: 1500000
    })
})
{'num_seq': tensor([4, 4])}
The batch only contains num_seq, none of the other tensors. I suspect the reason is the implementation of the dataloader or the collate function. What is the solution?
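Since num_seq varies from row to row (each row is split on '_eos'), the concatenated tensors can end up with different lengths per item, so I assume the default collate would not be able to stack them once I return them. Would a custom collate_fn that pads each field to the longest item in the batch be the right direction? A rough, untested sketch of what I have in mind (my_collate and the padding values are my own guesses):

from torch.nn.utils.rnn import pad_sequence

def my_collate(batch):
    # batch is a list of dicts as returned by __getitem__;
    # pad each variable-length field to the longest item in the batch
    return {
        "num_seq": torch.tensor([item["num_seq"] for item in batch]),
        "source_ids": pad_sequence(
            [item["source_ids"] for item in batch], batch_first=True, padding_value=0
        ),
        "source_masks": pad_sequence(
            [item["source_masks"] for item in batch], batch_first=True, padding_value=0
        ),
        "targets_ids": pad_sequence(
            [item["targets_ids"] for item in batch], batch_first=True, padding_value=-100
        ),
        "targets_masks": pad_sequence(
            [item["targets_masks"] for item in batch], batch_first=True, padding_value=0
        ),
    }

dataloader = DataLoader(dataset, batch_size=2, num_workers=2, collate_fn=my_collate)

Is that the right approach, or is the problem somewhere else?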