I am trying to implement a Siamese network dedicated to detecting similarities between sentences.
The input to the dataset is three lists (anchor, positive, and negative).
My custom dataset is implemented in the following way:
# Maximum number of word-piece tokens kept per description (truncate/pad target).
max_seq_length = 16


class DescriptionDataset(data.Dataset):
    """Triplet dataset for Siamese training over sentences.

    Each item is a 3-tuple of token-id tensors: (anchor, positive, negative).
    Relies on a module-level ``tokenizer`` (presumably a BERT tokenizer —
    confirm it is defined before items are fetched).
    """

    def __init__(self, positive_list, negative_list, anchor_list):
        self.negative_list = negative_list
        self.positive_list = positive_list
        self.anchor_list = anchor_list

    def __getitem__(self, index):
        # Tokenize all three members of the triplet lazily, per access.
        return (
            self._to_tokens(self.anchor_list[index]),
            self._to_tokens(self.positive_list[index]),
            self._to_tokens(self.negative_list[index]),
        )

    def _to_tokens(self, description):
        """Tokenize *description*, truncate/pad to max_seq_length, return an id tensor.

        BUG FIXES vs. original: the method was defined as ``_to_tokens`` but
        called as ``self.to_tokens``; it lacked the ``self`` parameter; and it
        never returned the tensor it built (``__getitem__`` would get None).
        """
        tokenized_description = tokenizer.tokenize(description)
        # Truncate to the fixed sequence length.
        if len(tokenized_description) > max_seq_length:
            tokenized_description = tokenized_description[:max_seq_length]
        ids_description = tokenizer.convert_tokens_to_ids(tokenized_description)
        # Right-pad with 0 (the [PAD] id in standard BERT vocabularies).
        ids_description += [0] * (max_seq_length - len(ids_description))
        return torch.tensor(ids_description)

    def __len__(self):
        return len(self.anchor_list)
My network is implemented in the following way:
def __init__(self, encoder, encoder_dim, n_hidden, use_gpu=True):
    """Siamese network: a shared encoder followed by a linear projection.

    Args:
        encoder: module mapping token ids to feature vectors (e.g. BERT).
        encoder_dim: dimensionality of the encoder's output features.
        n_hidden: dimensionality of the projected embedding space.
        use_gpu: stored flag; not consulted inside this method.
    """
    super(SiameseNetwork, self).__init__()
    self.use_gpu = use_gpu
    self.encoder = encoder
    self.n_hidden = n_hidden
    # BUG FIX: original used the undefined name `hidden_dim`; the projection
    # width is the `n_hidden` constructor argument.
    self.fc = nn.Linear(encoder_dim, n_hidden, bias=False)
def forward_embedding(self, x):
    """Encode a batch of inputs and project into the embedding space.

    NOTE(review): the original referenced the undefined names `cache` and
    `output` (guaranteed NameError); the evident intent — encode, then apply
    the linear projection — is restored here. If the encoder is a HuggingFace
    BERT, it may return a tuple rather than a tensor — confirm and select the
    pooled/sequence output before projecting.
    """
    encoded = self.encoder(x)
    # Project encoder features down to the n_hidden-dimensional embedding.
    return self.fc(encoded)
def forward(self, x, y, z):
    """Triplet forward pass.

    Args:
        x: anchor batch; y: positive batch; z: negative batch.

    Returns:
        (anchor-positive L2 distances, anchor-negative L2 distances,
         anchor embeddings, positive embeddings, negative embeddings).
    """
    # Embed all three inputs through the shared tower.
    emb_anchor, emb_pos, emb_neg = (self.forward_embedding(t) for t in (x, y, z))
    pos_dist = F.pairwise_distance(emb_anchor, emb_pos, 2)
    neg_dist = F.pairwise_distance(emb_anchor, emb_neg, 2)
    return pos_dist, neg_dist, emb_anchor, emb_pos, emb_neg
As an encoder, I am going to use the pre-trained BERT model.
The problem is that I have no idea how to implement a custom DataLoader, in particular how to override the default collate_fn function.
Can you help me?