Simplest collate for returning a list in each batch

If I want to include a list of sentences in a dataset item, like this:

import torch

class Dataset(torch.utils.data.Dataset):
	def __init__(self):
		self.items = [
			{'number': 1, 'sents': ['sent1_1', 'sent1_2']},
			{'number': 2, 'sents': ['sent2_1', 'sent2_2']},
			{'number': 3, 'sents': ['sent3_1', 'sent3_2']},
		]
	def __len__(self):
		return len(self.items)
	def __getitem__(self, idx):
		return self.items[idx]

And then use a data loader with a default collate function, like this:

next(iter(torch.utils.data.DataLoader(Dataset(), batch_size=2)))

The default collate function will group the sentences across items like this:

{'number': tensor([1, 2]), 'sents': [('sent1_1', 'sent2_1'), ('sent1_2', 'sent2_2')]}

When what I want is for the lists of sentences within each item to be kept together like this:

{'number': tensor([1, 2]), 'sents': [('sent1_1', 'sent1_2'), ('sent2_1', 'sent2_2')]}

What’s the simplest collate function that will do this for me?

This seems to be more of a Python question than a torch question, since you are working with lists, dictionaries, and strings. Getting the numbers into a torch tensor, however, can be done like this (inside the dataset class, where each item is a dictionary):

torch.tensor([x['number'] for x in self.items])
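As for the collate function itself, here is a minimal sketch of one way to do it, assuming every item is a dictionary with exactly the 'number' and 'sents' keys from the question. The names SentenceDataset and collate_keep_sents are made up for illustration; collate_fn is the DataLoader argument that replaces the default collate. The function builds the batch dictionary by hand, so the sentence lists are never transposed across items:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SentenceDataset(Dataset):
    """Same toy data as in the question (class name is illustrative)."""

    def __init__(self):
        self.items = [
            {'number': 1, 'sents': ['sent1_1', 'sent1_2']},
            {'number': 2, 'sents': ['sent2_1', 'sent2_2']},
            {'number': 3, 'sents': ['sent3_1', 'sent3_2']},
        ]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def collate_keep_sents(batch):
    """Hypothetical collate: batch is a list of item dicts from the dataset."""
    return {
        # Stack the numbers into a single tensor, as the default collate would.
        'number': torch.tensor([item['number'] for item in batch]),
        # Keep each item's sentences together as one tuple per item.
        'sents': [tuple(item['sents']) for item in batch],
    }


loader = DataLoader(SentenceDataset(), batch_size=2, collate_fn=collate_keep_sents)
batch = next(iter(loader))
print(batch)
# {'number': tensor([1, 2]), 'sents': [('sent1_1', 'sent1_2'), ('sent2_1', 'sent2_2')]}
```

Since the sentences are plain strings, this also works when items have different numbers of sentences, which the default collate would reject.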

Hope this helps a bit.