Problem making an embedding layer for CNN document classification

Hi,
I am trying to build a CNN for document classification. I used Keras previously, so I am new to PyTorch, and the indexing part is pretty confusing to me.

I consider 150 words per document. I built my word-to-index dictionary and converted each word in the documents to its index (a sketch of this step follows the sample below).
So, as a training sample, I have the following tensor:

Sample input size:  torch.Size([1, 150])
Sample input: 
 tensor([[1685,  190, 5459,  727, 1295,  772, 5460,  102, 4425, 9076,  935,    7,
         1200, 9077,   25,   83,  498,  830, 2169,    7, 4426,   27,  533, 1296,
          199,  167,  433, 5461, 4427,  592,   26, 6298,   23,   34, 9078,   15,
          149, 5462, 9079,  285,  128, 6299, 1201,   15,   46,  416,  190, 9080,
          399,  139,   29, 3175,  900, 2170,  772,   54, 2880,  158,  482,   15,
          371, 5463, 9081, 3488, 1686,   26, 9082, 5464,   22,  901,  336, 1748,
         9083, 5465, 1531,  694,  134, 5466,  313, 9084, 9085, 5467,  772, 5468,
         2881,    5, 3488,   26, 5463,  371, 5469, 2695,  679, 1921,  167, 9086,
         2170,  520, 4428,  450,   72,  336, 6300,  521,   26,  695,  694, 1297,
           46, 6301,  433,  100,  337,   33,   61, 5470,  620,    6, 3176, 9087,
            2, 2326, 9088,  451,  339,  695,  935,  772, 2039, 9089,   33, 6302,
           61,   60, 2696,    2, 2327, 9090,  451,  773, 2697,   15,   83,  498,
         1531, 1114,    7,   34, 1922,  290]])

Sample label size:  torch.Size([1])
Sample label: 
 tensor([1.])
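
For reference, here is a minimal sketch of the kind of preprocessing I mean; the example data and the encode helper are illustrative, but vocab_to_int is the dictionary I use later:

from collections import Counter
import torch

# a list of tokenized documents (illustrative data)
documents = [["the", "cat", "sat", "on", "the", "mat"],
             ["the", "dog", "ran", "away"]]

# build the word-to-index dictionary; index 0 is reserved for padding
counts = Counter(tok for doc in documents for tok in doc)
vocab_to_int = {word: i + 1 for i, (word, _) in enumerate(counts.most_common())}

def encode(doc, seq_len=150):
    # map tokens to indices, then pad/truncate to a fixed length of 150
    ids = [vocab_to_int[tok] for tok in doc][:seq_len]
    return ids + [0] * (seq_len - len(ids))

sample = torch.tensor([encode(documents[0])])  # sample.size() == torch.Size([1, 150])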

I made the following network:

class CNN(nn.Module):
    
    def __init__(self, vocab_size, output_size, embedding_dim, prob_drop):
        super(CNN, self).__init__()
        
        #Arguments"
        filter_sizes = [1,2,3]
        num_filters = 36

        self.vocab_size = vocab_size
        self.output_size = output_size
        
        self.embedding = nn.EmbeddingBag(vocab_size, embedding_dim, sparse=True)
        self.conv1 = nn.Conv2d(1,num_filters, (filter_sizes[0], embedding_dim))
        self.conv2 = nn.Conv2d(1,num_filters, (filter_sizes[1], embedding_dim))
        self.conv3 = nn.Conv2d(1,num_filters, (filter_sizes[2], embedding_dim))
#         self.conv4 = nn.Conv2d(1, num_filters, (filter_sizes[3], embedding_dim))
        self.dropout = nn.Dropout(prob_drop)
        self.label = nn.Linear(len(filter_sizes) * num_filters, output_size)  # classifier over concatenated pooled features
    
    def conv_(self, val, conv_layer):
        
        conv_out = conv_layer(val)
      
        activation = F.relu(conv_out.squeeze(3))  # activation.size() = (batch_size, out_channels, dim1)
        max_out = F.max_pool1d(activation, activation.size()[2]).squeeze(2)  # max_out.size() = (batch_size, out_channels)

        return max_out
    
    def forward(self, x):
        x = x.long()
       
        input_ = self.embedding(x)
    
        input_ = input_.unsqueeze(1)
      
        out1 = self.conv_(input_,self.conv1)
        
        out2 = self.conv_(input_,self.conv2)
        
        out3 = self.conv_(input_, self.conv3)
        
        all_out = torch.cat((out1, out2, out3), 1)  # (batch_size, num_filters * len(filter_sizes))
        fc_in = self.dropout(all_out)
        logits = self.label(fc_in)
        
        return logits

vocab_size = len(vocab_to_int)+1
output_size = 1

embedding_dim = 100 
prob_drop =0.1
net = CNN(vocab_size, output_size, embedding_dim, prob_drop)
lr = 0.001

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr = lr)

The training part for one sample is as follows:

net.train()
for e in range(epochs):
    

    for inputs, labels in train_one_loader:
        print(inputs.size())

        # zero the parameter gradients
        optimizer.zero_grad()
        
        
        outputs = net(inputs)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()


However, the output I receive from the embedding layer during training has size [1, 100] instead of [1, 150, 100], and this causes the error.
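
For reference, here is a minimal repro of the shape I am seeing, using the same kind of embedding layer as in __init__ (with a made-up vocabulary size):

import torch
import torch.nn as nn

x = torch.randint(0, 5000, (1, 150))   # one document of 150 word indices
emb_bag = nn.EmbeddingBag(5000, 100)   # same layer type as in my __init__
print(emb_bag(x).shape)                # torch.Size([1, 100]) instead of [1, 150, 100]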

I am guessing that I am missing one step in my training loop, but I cannot figure it out.
Would you please help me solve this problem?

However, nn.Conv2d expects 4-dimensional input; it would not work for 3-dimensional input.

For an image it is like,
[batch_size, number_of_channels, height, width]

the output of nn.Embedding will be like,
[number_of_words, embedding_dimension]

We could use nn.Conv1d after nn.Embedding, something like this:

conv = nn.Conv1d(1, 10, 3) # suppose in_channels is 1, out_channels is 10, kernel_size is 3
emb = nn.Embedding(150, 100) # 150 words, 100 embedding size
doc = torch.LongTensor([0, 1, 2]) # suppose our document has 3 words
z = emb(doc.long()) # output shape [3, 100]
e = z.reshape(3, 1, 100) # reshape it to 3 words, 1 channel, 100 embedding size, to use it with conv1d
conv(e).shape

torch.Size([3, 10, 98])

Or if we use nn.Conv2d after nn.Embedding, then we will have to split the embedding size 100 into 10x10:

conv = nn.Conv2d(1, 10, 3) # suppose in_channels is 1, out_channels is 10, kernel_size is 3
emb = nn.Embedding(150, 100) # 150 words, 100 embedding size
doc = torch.LongTensor([0, 1, 2]) # suppose our document has 3 words
z = emb(doc.long()) # output shape [3, 100]
e = z.reshape(3, 1, 10, 10) # reshape it to 3 words, 1 channel, 10x10 embedding size, to use it with conv2d
conv(e).shape

torch.Size([3, 10, 8, 8])

@vainaijr, thanks for your explanation.

nn.Embedding takes two numbers. The first number is the vocabulary size + 1, not 150 (which is the length of each document). The second number is the embedding dimension, which I set to 100. I follow the steps from here: https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html.
What I expect to get from the embedding layer in the training process is a tensor of size (batch size, length of the document, embedding dimension), but that is not what I am getting. I am getting a tensor of size [batch size, 100].
The reason I posted this question is that, for some reason, it looks like my embedding layer couldn't convert the indices in my sample to their corresponding embedding vectors.
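
In other words, what I expect to see is something like this (a small sketch with a made-up vocabulary size):

import torch
import torch.nn as nn

emb = nn.Embedding(5000, 100)          # (vocabulary size + 1, embedding dimension)
x = torch.randint(0, 5000, (1, 150))   # one document of 150 word indices
print(emb(x).shape)                    # torch.Size([1, 150, 100])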

About the CNN part: in NLP, we usually have nn.Conv2d(1, number_of_filters, (window size in words, word embedding dimension)). When building a CNN for text, we keep the kernel width equal to the embedding size and slide the kernel over the sequence word by word. Please visit this link:
http://www.wildml.com/2015/11/understanding-convolutional-neural-networks-for-nlp/.
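
To illustrate, here is a small sketch of the shapes I am expecting from the convolution and pooling steps (random data stands in for the embedding output; the sizes match my model above):

import torch
import torch.nn as nn
import torch.nn.functional as F

embedded = torch.randn(1, 150, 100).unsqueeze(1)  # stand-in for the embedding output: (1, 1, 150, 100)
conv = nn.Conv2d(1, 36, (3, 100))                 # 36 filters, each spanning 3 words and the full embedding
out = F.relu(conv(embedded).squeeze(3))           # (1, 36, 148)
out = F.max_pool1d(out, out.size(2)).squeeze(2)   # (1, 36)
print(out.shape)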

Therefore, my question remains unanswered.

Any thoughts on this?

In case it helps, here's my code for that link you've posted. I use nn.Conv1d there, but as @vainaijr said, nn.Conv2d works as well; the dimensions just have to be adjusted.

nn.EmbeddingBag != nn.Embedding

The first will give you the average of all the embeddings in the sequence, while the second will give you the individual embedding of each token in the sequence. I think you want to use the latter rather than the former.
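
A quick way to see the difference (with made-up sizes):

import torch
import torch.nn as nn

x = torch.randint(0, 5000, (1, 150))   # one document of 150 word indices

bag = nn.EmbeddingBag(5000, 100)       # averages the embeddings per document ('mean' mode by default)
emb = nn.Embedding(5000, 100)          # returns one embedding per token

print(bag(x).shape)                    # torch.Size([1, 100])
print(emb(x).shape)                    # torch.Size([1, 150, 100])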

@vdw, Thanks for sharing the code! Your blog is one of the most comprehensive on CNN text classification that I have read so far.

Most of the examples that I found use Conv2d, but I could not see the dimension adjustments in the code. Would you please explain why we need to adjust the embedding dimension of 100 to 10x10?

@dhpollack, Thanks! You are right! I wanted to use nn.Embedding. Now, I’m getting the correct dimension from the embedding layer.

@sasha, WildML is not my blog. I wish I had the time and skill for that :). But I can still give the explanation a shot.

Once you push a sequence through the embedding layer, you have a 2-dim tensor: (seq_len, embed_dim); we can ignore the batch size dimension here to keep it simple. Now this has the same shape as an image, so your sequence can be pushed through a nn.Conv2d layer. However, it does not make any semantic sense to convolve over the embedding dimension.

Look up the docs and check the constructor of nn.Conv2d:

Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

where kernel_size is either a tuple of two ints reflecting the size of the kernel, e.g., (3, 5), or a single int, say 3, which implies a square kernel (3, 3). As I said, it does not make sense to convolve over the embedding dimension. As such, you have to set the kernel size like:

Conv2d(in_channels, out_channels, kernel_size=(3, embed_dim), ...)

or any other height than 3. This snippet, kernel_size=(_, embed_dim), you should see in all the examples that use nn.Conv2d for text classification. If not, I would argue the model is not implemented correctly. Note that nn.Conv2d will happily use, say, kernel_size=(4, 4). It throws no error, but it's semantically wrong in this case of text classification.
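
To make this concrete, a small shape check (with seq_len=150, embed_dim=100, and 36 filters as example values):

import torch
import torch.nn as nn

x = torch.randn(1, 1, 150, 100)                # (batch, channel, seq_len, embed_dim)

good = nn.Conv2d(1, 36, kernel_size=(3, 100))  # kernel spans the full embedding dimension
bad = nn.Conv2d(1, 36, kernel_size=(4, 4))     # runs without error, but convolves over parts of the embedding

print(good(x).shape)                           # torch.Size([1, 36, 148, 1])
print(bad(x).shape)                            # torch.Size([1, 36, 147, 97])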

nn.Conv1d just simplifies this a bit. The constructor looks very similar:

Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

However, kernel_size can only be a single int and not a tuple of ints, since you convolve over just one dimension. As such, you would create the layer like:

Conv1d(in_channels, out_channels, kernel_size=3, ...)

which here implies a kernel size of (3, embed_dim).
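
For example, assuming the embedding output is permuted so that the embedding dimension becomes the channel dimension:

import torch
import torch.nn as nn

embedded = torch.randn(1, 150, 100)         # (batch, seq_len, embed_dim)
x = embedded.permute(0, 2, 1)               # (batch, embed_dim, seq_len) for Conv1d

conv1d = nn.Conv1d(100, 36, kernel_size=3)  # in_channels = embed_dim, the kernel covers 3 words
print(conv1d(x).shape)                      # torch.Size([1, 36, 148])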

In short, you can use both nn.Conv1d and nn.Conv2d. The only difference is that with nn.Conv2d you have to be a tad more careful about how you define the kernel size. With nn.Conv1d you cannot simply set the kernel size incorrectly.

I hope that helps.

@vdw, Thanks for the clarification about the blog :) and your comprehensive explanation! I visited the code that you shared, and it helped me a lot.

Ahaaaa, got it! Now it is crystal clear to me.
There is no real point in using Conv2d here because, for text, the kernel needs to move word by word while we keep the embedding dimension constant. By using Conv2d, we just carry a spare dimension throughout the convolution.

Thanks, @vdw again!