Hi,
I know there are many questions about the `collate_fn()` function for making inputs the same shape, but I still have trouble understanding it and using it. Because all my texts are longer than 512 tokens, I need to cut them into smaller pieces, so I applied a sliding window in the `__getitem__` method. Here is my `Dataset` class:
```python
import torch
from torch.utils.data import Dataset

MAX_LEN = 400
STRIDE = 20

class CustomDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len, stride):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text = dataframe.text
        self.targets = self.data.labels
        self.max_len = max_len
        self.stride = stride

    def __len__(self):
        return len(self.text)

    def __getitem__(self, index):
        text = str(self.text[index])
        text = " ".join(text.split())

        # Sliding window: return_overflowing_tokens together with stride splits
        # the text into several windows of max_length tokens each.
        inputs = self.tokenizer(
            text,
            None,
            max_length=MAX_LEN,
            stride=STRIDE,
            padding='max_length',
            truncation='only_first',
            return_overflowing_tokens=True,
            return_tensors='pt'
        )
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        token_type_ids = inputs["token_type_ids"]

        # return_tensors='pt' already gives tensors of shape [n_chunks, MAX_LEN],
        # so there is no need to wrap them in torch.tensor() again.
        return {
            'ids': ids,
            'mask': mask,
            'token_type_ids': token_type_ids,
            'targets': torch.tensor(self.targets[index], dtype=torch.float)
        }
```
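Just to show what a single sample looks like, this is roughly how I build the dataset (only a sketch; `bert-base-uncased` and `df` are placeholders for my tokenizer and dataframe):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder model name
training_set = CustomDataset(df, tokenizer, MAX_LEN, STRIDE)    # df is my dataframe

sample = training_set[0]
print(sample['ids'].shape)      # e.g. torch.Size([n_chunks, 400]); n_chunks varies per text
print(sample['targets'].shape)  # torch.Size([6]), one value per label
```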
As you can see, I return a dict with `ids`, `mask` (the attention mask), `token_type_ids` and the `targets`.
For example, with `batch_size = 8` the eight items of one batch could look like this:
```
IDs: torch.Size([971, 400])   | Mask: torch.Size([971, 400])   | TokenTypeIDs: torch.Size([971, 400])   | Targets: torch.Size([6])
IDs: torch.Size([17792, 400]) | Mask: torch.Size([17792, 400]) | TokenTypeIDs: torch.Size([17792, 400]) | Targets: torch.Size([6])
IDs: torch.Size([177, 400])   | Mask: torch.Size([177, 400])   | TokenTypeIDs: torch.Size([177, 400])   | Targets: torch.Size([6])
IDs: torch.Size([402, 400])   | Mask: torch.Size([402, 400])   | TokenTypeIDs: torch.Size([402, 400])   | Targets: torch.Size([6])
IDs: torch.Size([11, 400])    | Mask: torch.Size([11, 400])    | TokenTypeIDs: torch.Size([11, 400])    | Targets: torch.Size([6])
IDs: torch.Size([48, 400])    | Mask: torch.Size([48, 400])    | TokenTypeIDs: torch.Size([48, 400])    | Targets: torch.Size([6])
IDs: torch.Size([7, 400])     | Mask: torch.Size([7, 400])     | TokenTypeIDs: torch.Size([7, 400])     | Targets: torch.Size([6])
IDs: torch.Size([31, 400])    | Mask: torch.Size([31, 400])    | TokenTypeIDs: torch.Size([31, 400])    | Targets: torch.Size([6])
```
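Loading these with a plain DataLoader does not work (a sketch of what I ran; the default `collate_fn` cannot stack tensors whose first dimension differs):

```python
from torch.utils.data import DataLoader

training_loader = DataLoader(training_set, batch_size=8, shuffle=True)

# Raises something like "RuntimeError: stack expects each tensor to be equal size",
# because the eight 'ids' tensors above have different numbers of chunks.
batch = next(iter(training_loader))
```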
Because the shapes are not the same, I can't load them into the network. If I understand it correctly, I can use the `collate_fn()` function to bring all the data into the same shape. I wrote a custom function as described here:
```python
from torch.nn.utils.rnn import pack_sequence

def pad_collate(batch):
    data = [item['ids'] for item in batch]
    data = pack_sequence(data, enforce_sorted=False)
    targets = [item['targets'] for item in batch]
    return data, targets
```
which returns something like this:
```
(PackedSequence(data=tensor([[    3,  2843,  6406,  ..., 10437,    59,     4],
        [    3, 10696, 26897,  ...,  1464,   248,     4],
        [    3,  3396,  7083,  ..., 26971, 26924,     4],
        ...,
        [    3,   115,  1225,  ..., 26935, 15070,     4],
        [    3,    62, 26914,  ...,    21, 16923,     4],
        [    3, 26900,   860,  ...,     0,     0,     0]]), batch_sizes=tensor([2, 2, 2,  ..., 1, 1, 1]), sorted_indices=tensor([1, 0]), unsorted_indices=tensor([1, 0])), [tensor([1., 0., 0., 0., 0., 0.]), tensor([0., 0., 0., 1., 0., 0.])])
```
I can't understand this output. Can you help me with that?
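To at least make sense of it, I tried turning the `PackedSequence` back into a padded tensor (just a sketch, assuming `pad_packed_sequence` is the right inverse here; `training_set` is a placeholder for my dataset instance):

```python
from torch.nn.utils.rnn import pad_packed_sequence

# Build a tiny "batch" by hand from two samples.
packed, targets = pad_collate([training_set[0], training_set[1]])

# Convert the PackedSequence back into a padded tensor to inspect it.
padded, lengths = pad_packed_sequence(packed, batch_first=True)
print(padded.shape)  # [2, longest n_chunks, 400]
print(lengths)       # number of real (unpadded) chunks per document
```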
Also, this only brings the `ids` to a common length, but what about the `mask` and `token_type_ids`? How can I extend the `pad_collate()` function to also handle `mask` and `token_type_ids`, so that I can feed the data into the network and train it like this:
```python
def train(epoch):
    model.train()
    for _, data in enumerate(training_loader, 0):
        ids = data['ids'].to(device, dtype=torch.long)
        mask = data['mask'].to(device, dtype=torch.long)
        token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
        targets = data['targets'].to(device, dtype=torch.float)

        outputs = model(ids, mask, token_type_ids)
        loss = loss_fn(outputs, targets)
        if _ % 5000 == 0:
            print(f'Epoch: {epoch}, Loss: {loss.item()}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


for epoch in range(EPOCHS):
    train(epoch)
```
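This is roughly the direction I was thinking of for the extended collate, but I am not sure it is correct. It is only a sketch, assuming that `pad_sequence` can pad along the chunk dimension and that zero-padding the extra chunks (and their attention masks) is acceptable:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # Each item['ids'] has shape [n_chunks, 400]; pad_sequence pads the chunk
    # dimension to the longest document in the batch, giving tensors of shape
    # [batch_size, max_n_chunks, 400]. Padded chunks are all zeros, so their
    # attention mask stays zero as well.
    ids = pad_sequence([item['ids'] for item in batch], batch_first=True)
    mask = pad_sequence([item['mask'] for item in batch], batch_first=True)
    token_type_ids = pad_sequence([item['token_type_ids'] for item in batch], batch_first=True)
    targets = torch.stack([item['targets'] for item in batch])
    return {
        'ids': ids,
        'mask': mask,
        'token_type_ids': token_type_ids,
        'targets': targets,
    }
```

But even if this is right, I don't know whether the model can take the extra chunk dimension `[batch_size, max_n_chunks, 400]` directly, or whether I would have to reshape or loop over the chunks first.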