Siamese network data loader

I am trying to implement a Siamese network for detecting similarity between sentences.
The input to the dataset is three parallel lists of sentences (anchor, positive, and negative), where the entries at the same index form one training triplet.
My custom dataset is implemented in the following way:

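import torch
from torch.utils import data

# assumes `tokenizer` is a pre-trained BERT tokenizer created elsewhere
max_seq_length = 16

class DescriptionDataset(data.Dataset):
    def __init__(self, positive_list, negative_list, anchor_list):
        self.negative_list = negative_list
        self.positive_list = positive_list
        self.anchor_list = anchor_list
    def __getitem__(self, index):
        # one training triplet: (anchor, positive, negative) token-id tensors
        return (self._to_tokens(self.anchor_list[index]),
                self._to_tokens(self.positive_list[index]),
                self._to_tokens(self.negative_list[index]))
    def _to_tokens(self, description):
        tokenized_description = tokenizer.tokenize(description)
        # truncate to at most max_seq_length tokens
        if len(tokenized_description) > max_seq_length:
            tokenized_description = tokenized_description[:max_seq_length]
        ids_description = tokenizer.convert_tokens_to_ids(tokenized_description)
        # pad with id 0 so that every item has the same length
        padding = [0] * (max_seq_length - len(ids_description))
        ids_description += padding
        return torch.tensor(ids_description, dtype=torch.long)
    def __len__(self):
        return len(self.anchor_list)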
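If every item is a tuple of three token-id tensors padded to the same `max_seq_length`, as the dataset above intends, PyTorch's default `collate_fn` may already be enough: it regroups the tuples and stacks each position into a `(batch_size, max_seq_length)` tensor. The sketch below checks this without BERT; `ToyTokenizer` and `TripletDataset` are hypothetical stand-ins that only mimic the `tokenize` / `convert_tokens_to_ids` calls used above.

```python
import torch
from torch.utils import data

MAX_SEQ_LENGTH = 16  # same value as max_seq_length above


class ToyTokenizer:
    """Hypothetical stand-in for a BERT tokenizer: whitespace split,
    ids assigned on first sight, id 0 reserved for padding."""

    def __init__(self):
        self.vocab = {"[PAD]": 0}

    def tokenize(self, text):
        return text.lower().split()

    def convert_tokens_to_ids(self, tokens):
        return [self.vocab.setdefault(t, len(self.vocab)) for t in tokens]


class TripletDataset(data.Dataset):
    """Hypothetical compact version of a fixed-length triplet dataset."""

    def __init__(self, anchors, positives, negatives, tokenizer):
        self.anchors, self.positives, self.negatives = anchors, positives, negatives
        self.tokenizer = tokenizer

    def _to_tokens(self, text):
        tokens = self.tokenizer.tokenize(text)[:MAX_SEQ_LENGTH]  # truncate
        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        ids += [0] * (MAX_SEQ_LENGTH - len(ids))                 # pad to fixed length
        return torch.tensor(ids, dtype=torch.long)

    def __getitem__(self, index):
        return (self._to_tokens(self.anchors[index]),
                self._to_tokens(self.positives[index]),
                self._to_tokens(self.negatives[index]))

    def __len__(self):
        return len(self.anchors)


anchors = ["a cat sits on the mat", "stock prices fell today"]
positives = ["a cat is sitting on a mat", "shares dropped sharply today"]
negatives = ["the weather is sunny", "my cat likes fish"]

# No collate_fn argument: the default collate stacks each tuple position
loader = data.DataLoader(
    TripletDataset(anchors, positives, negatives, ToyTokenizer()), batch_size=2)
anchor_batch, positive_batch, negative_batch = next(iter(loader))
print(anchor_batch.shape)  # torch.Size([2, 16])
```

A custom `collate_fn` only becomes necessary if the items differ in length (for example, when padding is moved out of the dataset) or if the batch should carry extra tensors such as attention masks.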

My network is implemented in the following way:

import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    def __init__(self, encoder, encoder_dim, n_hidden, use_gpu=True):
        super(SiameseNetwork, self).__init__()
        self.encoder = encoder
        self.n_hidden = n_hidden
        self.fc = nn.Linear(encoder_dim, n_hidden, bias=False)
    def forward_embedding(self, x):
        # assumes the encoder maps a batch of token ids to one vector per
        # sentence, shape (batch_size, encoder_dim)
        output = self.encoder(x)
        output = self.fc(output)
        return output
    def forward(self, x, y, z):
        # x: anchor, y: positive, z: negative (shared weights for all three)
        embedded_x = self.forward_embedding(x)
        embedded_y = self.forward_embedding(y)
        embedded_z = self.forward_embedding(z)
        dist_a = F.pairwise_distance(embedded_x, embedded_y, 2)  # anchor-positive
        dist_b = F.pairwise_distance(embedded_x, embedded_z, 2)  # anchor-negative
        return dist_a, dist_b, embedded_x, embedded_y, embedded_z
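To exercise the network end to end without downloading BERT, here is a minimal sketch that swaps in a hypothetical `MeanEmbeddingEncoder` (averaged token embeddings, ignoring padding id 0) as the encoder; `use_gpu` is dropped for brevity. The question does not say which loss is used, so the training step assumes the common triplet hinge loss `max(0, dist_a - dist_b + margin)`, which is what PyTorch's `F.triplet_margin_loss` computes from the three embeddings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MeanEmbeddingEncoder(nn.Module):
    """Hypothetical stand-in for BERT: averages token embeddings,
    ignoring padding id 0, so each sentence becomes one vector."""

    def __init__(self, vocab_size, dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim, padding_idx=0)

    def forward(self, ids):
        mask = (ids != 0).unsqueeze(-1).float()            # (batch, seq, 1)
        summed = (self.embedding(ids) * mask).sum(dim=1)    # (batch, dim)
        return summed / mask.sum(dim=1).clamp(min=1.0)      # mean over real tokens


class SiameseNetwork(nn.Module):
    def __init__(self, encoder, encoder_dim, n_hidden):
        super().__init__()
        self.encoder = encoder
        self.fc = nn.Linear(encoder_dim, n_hidden, bias=False)

    def forward_embedding(self, x):
        return self.fc(self.encoder(x))

    def forward(self, x, y, z):
        ex, ey, ez = (self.forward_embedding(t) for t in (x, y, z))
        dist_a = F.pairwise_distance(ex, ey, 2)  # anchor-positive distance
        dist_b = F.pairwise_distance(ex, ez, 2)  # anchor-negative distance
        return dist_a, dist_b, ex, ey, ez


torch.manual_seed(0)
net = SiameseNetwork(MeanEmbeddingEncoder(vocab_size=100, dim=32),
                     encoder_dim=32, n_hidden=8)
# Random non-padding token ids: batch of 4 triplets, 16 tokens each
anchor, positive, negative = (torch.randint(1, 100, (4, 16)) for _ in range(3))

dist_a, dist_b, *_ = net(anchor, positive, negative)
margin = 1.0
# Triplet hinge loss: push dist_a below dist_b by at least the margin
loss = F.relu(dist_a - dist_b + margin).mean()
loss.backward()
```

With BERT, the encoder would need a small wrapper that reduces the per-token output to one vector per sentence (for example the pooled or `[CLS]` output); which output to use is a design choice not fixed by the question.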

As an encoder, I am going to use the pre-trained BERT model.
The problem is that I have no idea how to implement a custom DataLoader, in particular how to override the default collate_fn function.
Can you help me?
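A custom `collate_fn` receives the list of items returned by `__getitem__` for one batch and must return whatever the training loop should get. The sketch below assumes a hypothetical dataset that returns *unpadded* 1-D token-id tensors; `triplet_collate` then regroups anchors, positives, and negatives, pads each group to its longest sequence with `torch.nn.utils.rnn.pad_sequence`, and builds an attention mask from the true lengths. The names `triplet_collate`, `VariableLengthTriplets`, and `PAD_ID = 0` are illustrative assumptions.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

PAD_ID = 0  # assumption: id 0 is the padding token


def triplet_collate(batch):
    """Custom collate_fn for items shaped (anchor_ids, positive_ids, negative_ids),
    each a 1-D LongTensor of any length. Returns one (ids, mask) pair per group."""
    out = []
    for group in zip(*batch):  # all anchors, then all positives, then all negatives
        ids = pad_sequence(list(group), batch_first=True, padding_value=PAD_ID)
        lengths = torch.tensor([len(seq) for seq in group])
        # mask[i, j] = 1 if position j is a real token of sequence i, else 0
        mask = (torch.arange(ids.size(1)).unsqueeze(0) < lengths.unsqueeze(1)).long()
        out.append((ids, mask))
    return tuple(out)


class VariableLengthTriplets(Dataset):
    """Hypothetical dataset returning already-tokenized, unpadded id tensors."""

    def __init__(self, triplets):
        self.triplets = triplets

    def __getitem__(self, index):
        return tuple(torch.tensor(ids, dtype=torch.long) for ids in self.triplets[index])

    def __len__(self):
        return len(self.triplets)


triplets = [([5, 6, 7], [5, 8], [9, 10, 11, 12]),
            ([13, 14], [13, 15, 16], [17])]
loader = DataLoader(VariableLengthTriplets(triplets), batch_size=2,
                    collate_fn=triplet_collate)
(anchor_ids, anchor_mask), (pos_ids, pos_mask), (neg_ids, neg_mask) = next(iter(loader))
```

Padding per batch rather than to a global maximum keeps batches of short sentences small, and the masks can be passed to BERT implementations such as Hugging Face's `BertModel` through their `attention_mask` argument so that padded positions are ignored.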