Why is there a `torch.cat` in the collate function for a text dataloader?

I am looking at a PyTorch tutorial for a simple text classifier. The tutorial is simple enough, but the collate function that builds batches of training data confuses me: I am not clear on why it uses torch.cat to combine the text lists.

The code below creates a dataloader that yields the labels, the numericalized sentences, and their offsets. My question is: why do the authors of the tutorial call torch.cat() on text_list in the batch? It seems like concatenating the texts would blur the boundaries between them. Will the data go from being a batch of 8 texts to a batch of just 1 big text?

```python
import torch
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = get_tokenizer('basic_english')
# `vocab` is built earlier in the tutorial (not shown here)
text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: int(x) - 1

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        # store each sequence length; the cumsum below turns them into start indices
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)  # <-- QUESTION: why use cat here?
    return label_list.to(device), text_list.to(device), offsets.to(device)

train_iter = AG_NEWS(split='train')
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
```

I am asking this because when I look at some sample data using next(iter(dataloader)), I cannot even tell where one example ends and the next one begins. I forced the sequence length of my texts to be just 2 so that I could see the output better: there are 8 examples of length 2. Just by looking at the output, it seems like there is a single vector of 16 elements instead of 8 vectors of 2 elements each. That could just be a limitation of how tensors are printed, but I wanted to make sure I was not doing anything weird.


```
(tensor([ 1059,   454,   431,   425,    58,     8,    78,   798, 38487,   410,
          202,  1643,   272,  1197, 16858,    30]),
 tensor([3, 3, 3, 3, 3, 3, 3, 3]))
```

So, can anyone explain why torch.cat is used here? Is it necessary, or are there other ways to do this?

Your observation seems to be right and is also described in the linked tutorial:

> In this example, the text entries in the original data batch input are packed into a list and concatenated as a single tensor for the input of nn.EmbeddingBag. The offset is a tensor of delimiters to represent the beginning index of the individual sequence in the text tensor. Label is a tensor saving the labels of individual text entries.
> Although the text entries here have different lengths, nn.EmbeddingBag module requires no padding here since the text lengths are saved in offsets.
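A minimal sketch of that packing, using three made-up token-ID sequences of different lengths as stand-ins for `text_pipeline` output, shows how the leading `[0]` and `offsets[:-1].cumsum(dim=0)` in `collate_batch` turn sequence lengths into start indices:

```python
import torch

# Hypothetical numericalized sentences (stand-ins for text_pipeline output).
seqs = [torch.tensor([11, 12]), torch.tensor([21, 22, 23]), torch.tensor([31])]

# Same bookkeeping as collate_batch: a leading 0, then each sequence length.
lengths = [0] + [s.size(0) for s in seqs]            # [0, 2, 3, 1]

# Drop the last length and take a running sum -> start index of each sequence.
offsets = torch.tensor(lengths[:-1]).cumsum(dim=0)   # tensor([0, 2, 5])

# One flat 1-D tensor holding every token; no padding is needed.
text = torch.cat(seqs)                               # tensor([11, 12, 21, 22, 23, 31])
```

The flat tensor alone loses the boundaries, but `text` and `offsets` together keep all of them.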

where the offsets are used in:

```python
import torch.nn as nn

class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        # EmbeddingBag takes the flat `text` tensor plus `offsets` and returns
        # one pooled (mean by default) vector per sequence: (batch, embed_dim)
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
```
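To check what `self.embedding(text, offsets)` computes, here is a small sketch (toy sizes and token IDs are arbitrary) comparing `nn.EmbeddingBag` in its default `mode="mean"` against a plain `nn.Embedding` applied to each slice and averaged:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, embed_dim = 10, 4
bag = nn.EmbeddingBag(vocab_size, embed_dim, mode="mean")  # "mean" is the default mode
emb = nn.Embedding(vocab_size, embed_dim)
with torch.no_grad():
    emb.weight.copy_(bag.weight)  # share weights so the outputs are comparable

text = torch.tensor([1, 2, 3, 4, 5, 6])  # three sequences, flattened
offsets = torch.tensor([0, 2, 5])        # they start at indices 0, 2 and 5

bagged = bag(text, offsets)              # one pooled vector per sequence

# Same result by hand: embed each slice separately and average it.
bounds = offsets.tolist() + [text.numel()]
manual = torch.stack(
    [emb(text[s:e]).mean(dim=0) for s, e in zip(bounds[:-1], bounds[1:])]
)
```

Because the bag returns one row per offset, its output is already `(batch, embed_dim)`, which is why `forward` can pass it straight to `self.fc`.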

@ptrblck Thanks for your comment and for validating my intuition. I had not seen the EmbeddingBag layer before, but now I see that it takes the whole batch as one long 1-D tensor and uses the offsets to find the boundaries of each example. Haha, it is not a structure I was used to, since the regular Embedding layer does not work that way.
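The boundaries can also be recovered explicitly. A minimal sketch using the flat tensor printed in the question, assuming the offsets `collate_batch` would produce for 8 sequences of length 2, splits it back into the individual examples with `torch.tensor_split`:

```python
import torch

# Flat text tensor printed in the question (8 examples, each of length 2).
text = torch.tensor([1059, 454, 431, 425, 58, 8, 78, 798,
                     38487, 410, 202, 1643, 272, 1197, 16858, 30])
# Offsets collate_batch would produce for 8 sequences of length 2.
offsets = torch.arange(0, 16, 2)  # start indices 0, 2, 4, ..., 14

# Split at every start index except the first to get the examples back.
examples = torch.tensor_split(text, offsets[1:].tolist())

for i, ex in enumerate(examples):
    print(i, ex.tolist())
```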

Does the EmbeddingBag layer work better than a standard embedding layer? It seems like I can still use existing GloVe embeddings and other downloaded embeddings with EmbeddingBag too.
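Loading pretrained vectors should work through `nn.EmbeddingBag.from_pretrained`, which mirrors `nn.Embedding.from_pretrained`. A minimal sketch, where a random tensor stands in for GloVe vectors already aligned to your vocab (that alignment step is assumed and not shown):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained matrix (e.g. GloVe vectors aligned
# to your vocab): one row per token ID, here 5 tokens x 3 dims.
pretrained = torch.randn(5, 3)

# EmbeddingBag initialized from the pretrained rows;
# freeze=True keeps them fixed during training.
bag = nn.EmbeddingBag.from_pretrained(pretrained, freeze=True, mode="mean")

text = torch.tensor([0, 1, 2, 3, 4])
offsets = torch.tensor([0, 3])  # two bags: tokens [0, 1, 2] and [3, 4]
out = bag(text, offsets)        # shape (2, 3)
```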

This blog post explains the difference between these layers. I don’t know which layer would work better for which use cases, so you should definitely try out both :wink:
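If you want to try the regular `nn.Embedding` route, a minimal sketch of the padded alternative could look like this. It assumes token ID 0 is reserved for padding (a hypothetical choice; your vocab may differ), uses made-up token IDs, and applies mean pooling to mirror EmbeddingBag's default:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)

PAD_IDX = 0  # hypothetical: assumes token ID 0 is reserved for padding
seqs = [torch.tensor([4, 5]), torch.tensor([6, 7, 8]), torch.tensor([9])]

# Padded layout: shape (batch, max_len) = (3, 3); short rows are filled with PAD_IDX.
padded = pad_sequence(seqs, batch_first=True, padding_value=PAD_IDX)

embedding = nn.Embedding(10, 4, padding_idx=PAD_IDX)
emb = embedding(padded)  # (3, 3, 4)

# Zero out padding positions and average over the real tokens only,
# which mimics EmbeddingBag's default mode="mean".
mask = (padded != PAD_IDX).unsqueeze(-1).float()    # (3, 3, 1)
pooled = (emb * mask).sum(dim=1) / mask.sum(dim=1)  # (3, 4)
```

The trade-off is that the padded tensor holds `batch * max_len` positions, while the concatenated layout stores only the real tokens.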
