Using a collate_fn function to handle texts of different sizes

Hi,

I know there are many questions about using the collate_fn() function to make inputs the same shape, but I still have trouble understanding and using it.

Because all my texts are longer than 512 tokens, I need to cut them into smaller pieces, so I applied a sliding window in the __getitem__ function. Here is my Dataset class:

# Sliding-window settings for tokenization.
# MAX_LEN: value intended for the tokenizer's `max_length` (tokens per chunk).
MAX_LEN = 400
# STRIDE: value intended for the tokenizer's `stride` argument
# (presumably the token overlap between consecutive chunks -- confirm
# against the tokenizer documentation).
STRIDE = 20

class CustomDataset(Dataset):
    """Dataset that tokenizes one document per item into fixed-length windows.

    Each item is a dict holding the tokenizer output for a single document
    plus its label vector. Because the tokenizer is called with
    ``return_overflowing_tokens=True``, a document longer than ``max_len``
    produces several rows, so the first dimension of ``ids``, ``mask`` and
    ``token_type_ids`` differs from item to item.

    Args:
        dataframe: Object exposing ``text`` and ``labels`` attributes that
            can be indexed with the integer passed to ``__getitem__``
            (presumably a pandas DataFrame -- confirm against the caller).
        tokenizer: Callable tokenizer accepting the keyword arguments used
            in ``__getitem__`` and returning a mapping with ``input_ids``,
            ``attention_mask`` and ``token_type_ids``.
        max_len: Passed to the tokenizer as ``max_length``.
        stride: Passed to the tokenizer as ``stride``.
    """

    def __init__(self, dataframe, tokenizer, max_len, stride):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text = dataframe.text
        self.targets = self.data.labels
        self.max_len = max_len
        self.stride = stride

    def __len__(self):
        """Return the number of documents."""
        return len(self.text)

    def __getitem__(self, index):
        """Tokenize document ``index`` and return its tensors and targets.

        Returns:
            dict with keys ``ids``, ``mask``, ``token_type_ids`` (long
            tensors taken from the tokenizer output) and ``targets`` (float
            tensor built from ``self.targets[index]``).
        """
        # Collapse every run of whitespace (including newlines) to one space.
        text = " ".join(str(self.text[index]).split())

        inputs = self.tokenizer(
            text,
            None,
            # Use the values given to the constructor; previously the
            # module-level MAX_LEN / STRIDE were read here, silently ignoring
            # the `max_len` / `stride` arguments.
            max_length=self.max_len,
            stride=self.stride,
            padding='max_length',
            truncation='only_first',
            return_overflowing_tokens=True,
            return_tensors='pt',
        )

        # The tokenizer already returns tensors (return_tensors='pt');
        # torch.as_tensor avoids the copy-construct warning that
        # torch.tensor(<tensor>) emits while still enforcing the dtype.
        return {
            'ids': torch.as_tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.as_tensor(inputs['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.as_tensor(
                inputs['token_type_ids'], dtype=torch.long
            ),
            'targets': torch.tensor(self.targets[index], dtype=torch.float),
        }

As you can see, I return a dict with the ids, attention mask, token_type_ids and the target.
For example, with batch_size = 8 the items of one batch could have the following shapes:

|ID´s: torch.Size([971, 400]) | Mask: torch.Size([971, 400]) | TokenTypeID´s: torch.Size([971, 400]) | Targets: torch.Size([6])|
|ID´s: torch.Size([17792, 400]) | Mask: torch.Size([17792, 400]) | TokenTypeID´s: torch.Size([17792, 400]) | Targets: torch.Size([6])|
|ID´s: torch.Size([177, 400]) | Mask: torch.Size([177, 400]) | TokenTypeID´s: torch.Size([177, 400]) | Targets: torch.Size([6])|
|ID´s: torch.Size([402, 400]) | Mask: torch.Size([402, 400]) | TokenTypeID´s: torch.Size([402, 400]) | Targets: torch.Size([6])|
|ID´s: torch.Size([11, 400]) | Mask: torch.Size([11, 400]) | TokenTypeID´s: torch.Size([11, 400]) | Targets: torch.Size([6])|
|ID´s: torch.Size([48, 400]) | Mask: torch.Size([48, 400]) | TokenTypeID´s: torch.Size([48, 400]) | Targets: torch.Size([6])|
|ID´s: torch.Size([7, 400]) | Mask: torch.Size([7, 400]) | TokenTypeID´s: torch.Size([7, 400]) | Targets: torch.Size([6])|
|ID´s: torch.Size([31, 400]) | Mask: torch.Size([31, 400]) | TokenTypeID´s: torch.Size([31, 400]) | Targets: torch.Size([6])|

Because the shapes are not the same, I can't load them into the network. If I understand correctly, I can use a custom collate_fn() function to bring all the data into the same shape.

I wrote a custom function as described here:

def pad_collate(batch):
    """Collate a batch by packing the ``ids`` tensors and listing the targets.

    Args:
        batch: Sequence of dicts, each providing ``'ids'`` and ``'targets'``.

    Returns:
        A ``(packed_ids, labels)`` tuple: ``packed_ids`` is the
        ``PackedSequence`` built from every item's ``'ids'`` tensor, and
        ``labels`` is a plain list of the items' ``'targets'`` in batch
        order.

    Note:
        Only ``'ids'`` and ``'targets'`` are read; any other keys of the
        items (e.g. ``'mask'``, ``'token_type_ids'``) are not part of the
        result.
    """
    # enforce_sorted=False lets pack_sequence accept items in any length
    # order; it records the permutation in the returned PackedSequence.
    packed_ids = pack_sequence(
        [sample['ids'] for sample in batch],
        enforce_sorted=False,
    )
    labels = [sample['targets'] for sample in batch]
    return packed_ids, labels

which returns something like this:

(PackedSequence(data=tensor([[    3,  2843,  6406,  ..., 10437,    59,     4],
        [    3, 10696, 26897,  ...,  1464,   248,     4],
        [    3,  3396,  7083,  ..., 26971, 26924,     4],
        ...,
        [    3,   115,  1225,  ..., 26935, 15070,     4],
        [    3,    62, 26914,  ...,    21, 16923,     4],
        [    3, 26900,   860,  ...,     0,     0,     0]]), batch_sizes=tensor([2, 2, 2,  ..., 1, 1, 1]), sorted_indices=tensor([1, 0]), unsorted_indices=tensor([1, 0])), [tensor([1., 0., 0., 0., 0., 0.]), tensor([0., 0., 0., 1., 0., 0.])])

I can't understand this output; can you help me with that?
Also, this only handles the token ids, but what about the mask and token_type_ids?
How can I extend the pad_collate() function to include the mask and token_type_ids, so that I can feed the data into the network and train it like this:

def train(epoch):
    """Run one training pass over ``training_loader``.

    Relies on the module-level names ``model``, ``training_loader``,
    ``device``, ``optimizer`` and ``loss_fn``. Each batch is expected to be
    a mapping with the keys ``'ids'``, ``'mask'``, ``'token_type_ids'`` and
    ``'targets'``.

    Args:
        epoch: Epoch number, used only in the progress message.
    """
    model.train()
    for step, data in enumerate(training_loader):
        ids = data['ids'].to(device, dtype=torch.long)
        mask = data['mask'].to(device, dtype=torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
        targets = data['targets'].to(device, dtype=torch.float)

        # Clear gradients once per step (the original cleared them twice,
        # which was redundant).
        optimizer.zero_grad()

        outputs = model(ids, mask, token_type_ids)
        loss = loss_fn(outputs, targets)

        # Report the loss on the first batch and every 5000th batch after.
        if step % 5000 == 0:
            print(f'Epoch: {epoch}, Loss:  {loss.item()}')

        loss.backward()
        optimizer.step()

# Call train() once per epoch; EPOCHS is not defined in this snippet.
for epoch in range(EPOCHS):
    train(epoch)