I am looking at a PyTorch tutorial for a simple text classifier. The tutorial is simple enough, but what confuses me is the collate function used when creating batches of data for training. I am not clear on why the tutorial uses torch.cat to combine the text lists.
So the code below creates a dataloader that yields the text label and then the numericalized sentence itself. My question is: why do the authors of the tutorial use torch.cat() on the text_list in the batch? It seems like concatenating them will muddy or confuse the boundary between each text. Will the data go from being a batch of 8 texts to a batch of just 1 big text?
```python
import torch
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = get_tokenizer('basic_english')

# vocab is built earlier in the tutorial from the training data
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)  # <------- QUESTION IS WHY USE CAT HERE?
    return label_list.to(device), text_list.to(device), offsets.to(device)

train_iter = AG_NEWS(split='train')
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False,
                        collate_fn=collate_batch)
```
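To illustrate what I mean, here is a quick standalone check (made-up tensors, nothing from the tutorial) showing that torch.cat on a list of 1-D tensors does produce a single flat 1-D tensor:

```python
import torch

# Three made-up "texts" of different lengths, already numericalized
texts = [torch.tensor([1, 2]), torch.tensor([3, 4, 5]), torch.tensor([6])]

flat = torch.cat(texts)
print(flat)        # tensor([1, 2, 3, 4, 5, 6]) -- one flat tensor
print(flat.shape)  # torch.Size([6]) -- no batch dimension left
```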
I am asking this because when I look at some sample data using next(iter(dataloader)), even I cannot tell where one example ends and the next one begins. I forced the sequence length of my texts to be just 2 so I could read the output more easily. There are 8 examples of length 2, but just by looking at the output it seems like there is a single 16-element vector instead of 8 vectors of 2 elements each. That could just be a limitation of how tensors print, but I wanted to make sure I was not doing anything weird.
```python
next(iter(dataloader))
(tensor([ 1059,   454,   431,   425,    58,     8,    78,   798, 38487,   410,
           202,  1643,   272,  1197, 16858,    30]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]))
```
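One way I thought of to check that the examples are still separable is to split the flat tensor back up using the offsets that collate_batch also returns (a sketch with made-up stand-in values matching my length-2 setup, not the tutorial's code):

```python
import torch

# Stand-ins for what collate_batch returns: the flat text tensor and the
# cumulative start offset of each example within it
text = torch.tensor([1059, 454, 431, 425, 58, 8, 78, 798,
                     38487, 410, 202, 1643, 272, 1197, 16858, 30])
offsets = torch.arange(0, 16, 2)  # every example has length 2 here

# Each example's length is the gap between consecutive offsets;
# the last example runs to the end of the flat tensor
ends = torch.cat([offsets, torch.tensor([text.size(0)])])
lengths = torch.diff(ends)
pieces = torch.split(text, lengths.tolist())
print(len(pieces))  # 8 -- the batch is recoverable
print(pieces[0])    # tensor([1059, 454]) -- the first example
```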
Hence, can anyone explain why torch.cat is used here? Is it necessary, or are there other ways to do this?