Using nn.EmbeddingBag with a CNN in a PyTorch model

I am trying to build a text classification model in PyTorch using nn.EmbeddingBag and a CNN.

I know there are a bunch of NLP CNN models using nn.Embedding out there, but due to my hardware constraints I do not want to use nn.Embedding.

I have this simple model setup for starters (16 is the batch size):

import torch
import torch.nn as nn

class CNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels=embed_dim, out_channels=16, kernel_size=1),
            nn.ReLU(),
            nn.MaxPool1d(1)
        )
        self.fc = nn.Linear(16, num_class)

    def forward(self, text, offsets):
        x = self.embedding(text, offsets).unsqueeze(2)
        x = self.conv1(x)
        x = x.view(-1, 16)
        x = self.fc(x)
        return x

This CNN model “compiles” (the dimensions all match), but it does not train: accuracy stays at ~7% (random guessing). In other words, this model doesn’t work.

I should also add that the reason I’m trying to do this is to improve the following fully connected model (which works fine):

class FC(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

How can I implement a CNN model using the nn.EmbeddingBag?

If that’s not possible, how can I improve the FC model?

Thanks!

Based on your code snippet, it seems your temporal dimension has only a single value, which would also explain the kernel size of 1 in the conv layer. In that case, the conv operation is equivalent to a linear layer. Also, the pooling layer won’t have any effect with a kernel size of 1, as seen here:

import torch
import torch.nn as nn

x = torch.randn(2, 16, 24)
pool = nn.MaxPool1d(1)  # kernel size 1: each pooling window holds a single element
out = pool(x)
print((x == out).all())
> tensor(True)

so you can remove it.
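
To make the equivalence mentioned above concrete, here is a minimal sketch. The sizes (a vocab of 100, embed_dim of 8, two bags of four tokens) are made up for illustration and are not from the original post. It shows that nn.EmbeddingBag returns one pooled vector per bag, so no sequence axis is left, and that a Conv1d with kernel_size=1 applied to that length-1 input gives the same result as an nn.Linear with the same weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only (assumptions, not taken from the question).
vocab_size, embed_dim, out_channels = 100, 8, 16

# EmbeddingBag pools every bag into a single vector (default mode="mean"),
# so the output is (num_bags, embed_dim) with no temporal dimension.
embedding = nn.EmbeddingBag(vocab_size, embed_dim)
text = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])  # two bags, flattened
offsets = torch.tensor([0, 4])                 # start index of each bag
pooled = embedding(text, offsets)
print(pooled.shape)

conv = nn.Conv1d(embed_dim, out_channels, kernel_size=1)
linear = nn.Linear(embed_dim, out_channels)

# Copy the conv parameters into the linear layer:
# conv.weight is (out_channels, embed_dim, 1), linear.weight is (out_channels, embed_dim).
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1))
    linear.bias.copy_(conv.bias)

conv_out = conv(pooled.unsqueeze(2)).squeeze(2)  # (2, 8, 1) -> (2, 16, 1) -> (2, 16)
linear_out = linear(pooled)                      # (2, 8) -> (2, 16)

same = torch.allclose(conv_out, linear_out, atol=1e-5)
print(same)
```

If that holds, the CNN in the question is effectively the FC model with one extra linear layer and a ReLU in between, not a convolution over tokens. A convolution with a kernel size larger than 1 would need a per-token sequence of embeddings to slide over, which the pooled EmbeddingBag output does not provide.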