I am developing a QG (Question Generation) application, where my data is in the following format:
[{"context":"Some Context",
"question":["question1", "question2", "question3"]},
{"context":"Some other Context",
"question":["question1"]},
{"context":" Another Context",
"question":["question1", "question2", "question3","question4"]},
...
]
Initially, I created a dataframe replicating the context for every question. So I have created this Dataset, which works well, but doesn't allow me to correctly track my metrics (BLEU, ROUGE, …).
from typing import List, Union,Optional
class CustomDataset(Dataset):
    """Dataset pairing a context (source) with a question (target) for QG.

    Each item tokenizes ``"{PREFIX} {context}"`` as the source and the raw
    question as the target, and returns the tuple
    ``(source_ids, source_mask, target_ids, target_mask,
    original_source, original_target)`` with the id/mask entries as
    ``torch.long`` tensors.
    """

    def __init__(self, PREFIX, tokenizer,
                 X_context: np.ndarray,
                 y_question: Optional[np.ndarray] = None,
                 source_max_length: int = 32,
                 target_max_length: int = 32):
        """
        Args:
            PREFIX: task prefix prepended to every context (e.g. T5-style).
            tokenizer: any tokenizer exposing ``encode_plus``.
            X_context: array/sequence of context strings.
            y_question: array/sequence of target questions, one per context.
                Default changed from a mutable ``[]`` (shared across calls —
                a classic Python pitfall) to ``None``; omitting it behaves
                exactly as before.
            source_max_length: token budget for the context, before adding
                room for the prefix.
            target_max_length: token budget for the question.
        """
        self.tokenizer = tokenizer
        self.X_context = X_context
        self.y_question = [] if y_question is None else y_question
        # Reserve extra room for the prefix; word count is a rough proxy
        # for its token count — TODO confirm against the real tokenizer.
        self.source_max_length = source_max_length + len(PREFIX.split(' '))
        self.target_max_length = target_max_length
        self.PREFIX = PREFIX

    def __len__(self):
        """Number of (context, question) pairs."""
        return len(self.X_context)

    def __getitem__(self, idx):
        """Return the encoded pair at ``idx`` plus the original strings."""
        # --- Source ---
        original_source = self.X_context[idx]
        # BUG FIX: the original used the bare global name `PREFIX` here,
        # which raises NameError unless a same-named global happens to
        # exist; use the instance attribute instead.
        source = f"{self.PREFIX} {original_source}"
        source_encoder = self.encoder_plus(source, self.source_max_length)
        source_token_ids = torch.tensor(source_encoder['input_ids'],
                                        dtype=torch.long)
        source_mask = torch.tensor(source_encoder['attention_mask'],
                                   dtype=torch.long)
        # --- Target ---
        original_target = self.y_question[idx]
        target = f"{original_target}"
        target_encoder = self.encoder_plus(target, self.target_max_length)
        target_token_ids = torch.tensor(target_encoder['input_ids'],
                                        dtype=torch.long)
        target_mask = torch.tensor(target_encoder['attention_mask'],
                                   dtype=torch.long)
        return (source_token_ids, source_mask, target_token_ids, target_mask,
                original_source, original_target)

    def encoder_plus(self, text, L):
        """Tokenize ``text``, truncating/padding to exactly ``L`` tokens."""
        return self.tokenizer.encode_plus(text,
                                          max_length=L,
                                          truncation=True,
                                          padding="max_length")
But now I want to have all the questions that are correlated with the same context in the same batch, to be able to track all metrics on my model. The problem is that when I generate a question, if the reference questions are not in the same batch, I can't know which question to compare against. I imagine that I could possibly use a collate function, but I can't figure out how.