I am looking at a PyTorch tutorial for a simple text classifier. The tutorial is simple enough, but what confuses me is the collate function used when creating batches of data for training. I am not clear on why the tutorial uses torch.cat to combine the text lists.
So the code below creates a dataloader that yields the text label and then the numericalized sentence itself. My question is: why do the authors of the tutorial use torch.cat() on the text_list in the batch? It seems like concatenating them will muddy or confuse the boundary between each text. Will the data go from being a batch of 8 texts to a batch of just 1 big text?
```python
import torch
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = get_tokenizer('basic_english')

# vocab is built earlier in the tutorial from the training data
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)  # <------- QUESTION IS WHY USE CAT HERE?
    return label_list.to(device), text_list.to(device), offsets.to(device)

train_iter = AG_NEWS(split='train')
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False,
                        collate_fn=collate_batch)
```
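To illustrate what I mean, here is a quick standalone check (made-up tensors, nothing from the tutorial) showing that torch.cat on a list of 1-D tensors does produce a single flat 1-D tensor:

```python
import torch

# Three made-up "texts" of different lengths, already numericalized
texts = [torch.tensor([1, 2]), torch.tensor([3, 4, 5]), torch.tensor([6])]

flat = torch.cat(texts)
print(flat)        # tensor([1, 2, 3, 4, 5, 6]) -- one flat tensor
print(flat.shape)  # torch.Size([6]) -- no batch dimension left
```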
I am asking this because when I look at some sample data using next(iter(dataloader)), even I cannot tell where one example ends and the next one begins. I forced the sequence length of my texts to be just 2 so I could read the output more easily. There are 8 examples of length 2, but just by looking at the output it seems like there is a single 16-element vector instead of 8 vectors of 2 elements each. That could just be a limitation of how tensors print, but I wanted to make sure I was not doing anything weird.
```python
next(iter(dataloader))
(tensor([ 1059,   454,   431,   425,    58,     8,    78,   798, 38487,   410,
           202,  1643,   272,  1197, 16858,    30]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]))
```
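One way I thought of to check that the examples are still separable is to split the flat tensor back up using the offsets that collate_batch also returns (a sketch with made-up stand-in values matching my length-2 setup, not the tutorial's code):

```python
import torch

# Stand-ins for what collate_batch returns: the flat text tensor and the
# cumulative start offset of each example within it
text = torch.tensor([1059, 454, 431, 425, 58, 8, 78, 798,
                     38487, 410, 202, 1643, 272, 1197, 16858, 30])
offsets = torch.arange(0, 16, 2)  # every example has length 2 here

# Each example's length is the gap between consecutive offsets;
# the last example runs to the end of the flat tensor
ends = torch.cat([offsets, torch.tensor([text.size(0)])])
lengths = torch.diff(ends)
pieces = torch.split(text, lengths.tolist())
print(len(pieces))  # 8 -- the batch is recoverable
print(pieces[0])    # tensor([1059, 454]) -- the first example
```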
Hence, can anyone explain why torch.cat is used here? Is it necessary, or are there other ways to do this?