However, I’m altering the code as I go to make it a bit more “modern” and extensible. So I’ve also implemented a custom Dataset that, in combination with a DataLoader, is giving me the error above.
Here’s what I think is happening:
The dataset is a bunch of names and their ethnic origins. In __getitem__ the names are one-hot encoded at the character level, and a tensor with shape (len(name), len(character_vocab)) is returned (simplified sketch below).
The dataloader doesn’t like that the names are variable length and freaks out if batch_size > 1
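To illustrate, here is a minimal sketch of the kind of __getitem__ I mean (not my actual code; the class name and vocabulary are placeholders, and the origin labels are omitted for brevity):

import torch
from torch.utils.data import Dataset

class NamesDataset(Dataset):
    """Sketch: names one-hot encoded at the character level."""
    def __init__(self, names, vocab="abcdefghijklmnopqrstuvwxyz"):
        # `vocab` is a placeholder; the real character vocabulary is larger
        self.names = names
        self.char_to_idx = {c: i for i, c in enumerate(vocab)}

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        name = self.names[index]
        # shape (len(name), len(character_vocab)), so it varies with the name length
        encoded = torch.zeros(len(name), len(self.char_to_idx))
        for pos, char in enumerate(name):
            encoded[pos, self.char_to_idx[char]] = 1.0
        return encoded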
Question: What technique could I use to overcome this?
As I understand it, the problem is that the DataLoader's default collation expects all inputs in a batch to have equal length.
To fix that, you need a custom collate_fn.
If you look at the docs of torch.utils.data.DataLoader, they mention:
A custom collate_fn can be used to customize collation, e.g., padding sequential data to max length of a batch. See this section on more about collate_fn.
I would suggest you change the collate function; for your case, something like this should work:
import torch
from torch.utils.data import DataLoader

## trying to mimic your data with the input in the error
temp_data = [[[0 for _ in range(57)] for _ in range(8)],
             [[0 for _ in range(57)] for _ in range(4)]]

def custom_padding_collate(batch):
    """
    Takes a list of variable-length samples and pads them to the maximum
    length in the batch.

    Args:
        batch: list of samples handed to the DataLoader, of length batch_size

    Returns:
        x (torch.Tensor): padded data of shape (batch, max_length_in_batch, embedding_dim)
        x_lengths (torch.Tensor): one-dimensional tensor with the original length of each sample
    """
    batch_size = len(batch)
    x_lengths = [len(t) for t in batch]
    T = max(x_lengths)
    character_vocab_len = len(batch[0][0])
    for index in range(batch_size):
        # pad each sample with all-zero rows up to the batch maximum length T
        batch[index] = batch[index] + [[0.0] * character_vocab_len] * (T - len(batch[index]))
        # cast to float so padded and unpadded samples share a dtype for torch.stack
        batch[index] = torch.tensor(batch[index], dtype=torch.float)
    x = torch.stack(batch)
    x_lengths = torch.tensor(x_lengths)
    return x, x_lengths

dl = DataLoader(temp_data, batch_size=2, collate_fn=custom_padding_collate)

for x, lengths in dl:
    print(f"Data: {x}, Lengths: {lengths}")
    # This gives you the padded data and the original lengths
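As a side note, since your __getitem__ already returns tensors, you could also let torch.nn.utils.rnn.pad_sequence do the padding instead of padding Python lists by hand. A minimal sketch of an equivalent collate_fn (assuming each sample is a (len(name), len(character_vocab)) tensor):

import torch
from torch.nn.utils.rnn import pad_sequence

def pad_sequence_collate(batch):
    # batch is assumed to be a list of (len(name), len(character_vocab)) tensors
    lengths = torch.tensor([sample.shape[0] for sample in batch])
    # pads with zeros up to the longest sample and stacks into (batch, max_len, vocab)
    padded = pad_sequence(batch, batch_first=True)
    return padded, lengths

The lengths tensor returned here is also what torch.nn.utils.rnn.pack_padded_sequence expects if you later feed the padded batch to an RNN.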