@sasha, WildML is not my blog. I wish I had the time and skill for that :). But I can still give the explanation a shot.
Once you've pushed a sequence through the embedding layer, you have a 2-dim tensor of shape (seq_len, embed_dim); we can ignore the batch dimension here to keep it simple. This has the same shape as an image, so your sequence can be pushed through a nn.Conv2d layer. However, it does not make any semantic sense to convolve over the embedding dimension.
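A minimal sketch to make the shapes concrete (all sizes here are made up):

```python
import torch
import torch.nn as nn

# Made-up sizes, just to show the shape
vocab_size, embed_dim, seq_len = 1000, 100, 20

embedding = nn.Embedding(vocab_size, embed_dim)
tokens = torch.randint(0, vocab_size, (seq_len,))  # one sequence, batch dim ignored

x = embedding(tokens)
print(x.shape)  # torch.Size([20, 100]) -- (seq_len, embed_dim), image-like
```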
Look up the docs and check the constructor of nn.Conv2d:
Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
Here, kernel_size is either a tuple of two ints reflecting the size of the kernel, e.g., (3, 5), or a single int, say 3, which implies a square kernel (3, 3). As I said, it does not make sense to convolve over the embedding dimension. As such, you have to set the kernel size like:
Conv2d(in_channels, out_channels, kernel_size=(3, embed_dim), ...)
or any other int instead of 3. You should see this pattern, kernel_size=(_, embed_dim), in all the examples that use nn.Conv2d for text classification. If not, I would argue the model is not implemented correctly. Note that nn.Conv2d happily accepts, say, kernel_size=(4, 4). It throws no error, but it's semantically wrong for text classification.
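For example, here is a minimal sketch with made-up sizes; note that nn.Conv2d expects a channel dimension, so the embedded sequence gets an extra dimension of size 1:

```python
import torch
import torch.nn as nn

# All sizes are made up, just for illustration
batch_size, seq_len, embed_dim, vocab_size = 4, 20, 100, 1000

embedding = nn.Embedding(vocab_size, embed_dim)
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))

x = embedding(tokens)  # (batch_size, seq_len, embed_dim)
x = x.unsqueeze(1)     # (batch_size, 1, seq_len, embed_dim); Conv2d expects (N, C, H, W)

# The kernel spans the full embedding dimension and slides only over the sequence
conv = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, embed_dim))
out = conv(x)
print(out.shape)  # torch.Size([4, 16, 18, 1]) -- 18 = seq_len - 3 + 1
```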
nn.Conv1d just simplifies this a bit. The constructor looks very similar:
Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')
However, kernel_size can only be a single int, not a tuple of ints, since you convolve over just one dimension. As such, you would create the layer like:
Conv1d(in_channels, out_channels, kernel_size=3, ...)
which here implies a kernel size of (3, embed_dim).
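Again a minimal sketch with made-up sizes; note that nn.Conv1d expects its input as (batch, channels, length), so the embedding dimension becomes the channel dimension and you have to transpose first:

```python
import torch
import torch.nn as nn

# Same made-up sizes as before
batch_size, seq_len, embed_dim, vocab_size = 4, 20, 100, 1000

embedding = nn.Embedding(vocab_size, embed_dim)
tokens = torch.randint(0, vocab_size, (batch_size, seq_len))

x = embedding(tokens)  # (batch_size, seq_len, embed_dim)
x = x.transpose(1, 2)  # (batch_size, embed_dim, seq_len); Conv1d expects (N, C, L)

# in_channels=embed_dim, so a kernel of size 3 effectively covers (3, embed_dim)
conv = nn.Conv1d(in_channels=embed_dim, out_channels=16, kernel_size=3)
out = conv(x)
print(out.shape)  # torch.Size([4, 16, 18]) -- 18 = seq_len - 3 + 1
```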
In short, you can use both nn.Conv2d and nn.Conv1d. The only difference is that with nn.Conv2d you have to be a tad more careful about how you define the kernel size; with nn.Conv1d you simply cannot set the kernel size incorrectly.
I hope that helps.