BERT-based CNN - Convolution and Max Pooling

I’m trying to fine-tune a pre-trained BERT model by inserting a CNN layer.
In this model, the outputs of all transformer encoders are used, not only the output of the last
encoder: the output vectors of the encoders are concatenated to produce a matrix.
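
For concreteness, here is a minimal sketch (dummy tensors only, not my real model) of the stacking I have in mind, assuming output_hidden_states=True returns 13 hidden states for bert-base:

import torch

batch_size, seq_len, hidden = 32, 64, 768
# Stand-ins for the 13 hidden states (embedding output + 12 encoder outputs):
all_layers = [torch.randn(batch_size, seq_len, hidden) for _ in range(13)]

# Stack along a new "layer" dimension:
x = torch.stack(all_layers, dim=1)
print(x.shape)  # torch.Size([32, 13, 64, 768])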

The convolution is performed with a window of size (3, hidden size), where the hidden size is 768 in the BERT-base model, and a maximum value is extracted for each transformer encoder by applying max pooling to the convolution output.
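
To check the shape arithmetic of that window, a standalone sketch (single layer, single filter, no padding; all sizes are just for illustration):

import torch
import torch.nn as nn

# A (3, 768) kernel slid over a [64, 768] matrix leaves 62 positions along
# the sequence axis and 1 along the hidden axis:
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=(3, 768))
out = conv(torch.randn(1, 1, 64, 768))
print(out.shape)  # torch.Size([1, 1, 62, 1])

# Global max pooling over that map then yields one value per filter:
print(out.flatten(2).max(dim=2).values.shape)  # torch.Size([1, 1])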

By concatenating these maxima, a vector is produced and given as input to a fully connected network; applying softmax to the network’s output performs the classification.
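
The head itself would then be something like this (a sketch with hypothetical sizes; LogSoftmax pairs with nn.NLLLoss during training):

import torch
import torch.nn as nn

pooled = torch.randn(32, 13)  # one max value per layer, per example
head = nn.Sequential(nn.Linear(13, 3), nn.LogSoftmax(dim=1))
log_probs = head(pooled)      # [32, 3]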

My problem is that I can’t seem to find the right arguments to perform the convolution and the max pooling on that matrix.

With batch size = 32, there are 13 hidden states per example (the embedding output plus the outputs of the 12 transformer encoders), each of shape [64, 768], where 64 is the max length used in tokenization.

I want to perform the convolution on each layer’s output matrix ([64, 768]) separately, then apply max pooling to that convolution’s output. That way I should get one max value per layer, and these max values are fed into the fully connected network.
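
Here is my untested sketch of that intent, treating the 13 layers as channels and using groups=13 so each layer gets its own (3, 768) filter (I’m not sure groups is the right tool here, which is part of my question):

import torch
import torch.nn as nn

x = torch.randn(32, 13, 64, 768)  # [batch, layers, seq_len, hidden]

# groups=13 gives every layer its own filter, so layers don't mix:
conv = nn.Conv2d(in_channels=13, out_channels=13,
                 kernel_size=(3, 768), groups=13)

h = torch.relu(conv(x))             # [32, 13, 62, 1]
h = h.flatten(2).max(dim=2).values  # [32, 13] -- one max per layer
# h would then be the input to the classification head sketched above.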

My code is:

class BERT_Arch(nn.Module):

    def __init__(self, bert):
        super(BERT_Arch, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.conv = nn.Conv2d(in_channels=13, out_channels=13, kernel_size= (3, 768), padding=True) 
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=768, stride=1)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(9118464, 3)
        self.flat = nn.Flatten()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        _, _, all_layers = self.bert(sent_id, attention_mask=mask, output_hidden_states=True)
        # all_layers: tuple of 13 tensors, each of shape [32, 64, 768]
        x = torch.cat(all_layers, 0) # x= [416, 64, 768]
        x = self.conv(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.flat(x)
        x = self.fc(x)
        return self.softmax(x)

I keep getting an error saying that the convolution expected an input with a certain number of dimensions but got one with a different number:

<generator object BERT_Arch.forward.<locals>.<genexpr> at 0x7fbeffc2d200>
torch.Size([416, 64, 768])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-12-3a2c2cd7c02d> in <module>()
    362 
    363         # train model
--> 364         train_loss, _ = train()
    365 
    366         # evaluate model

5 frames
<ipython-input-12-3a2c2cd7c02d> in train()
    148 
    149         # get model predictions for the current batch
--> 150         preds = model(sent_id, mask)
    151 
    152         # compute the loss between actual and predicted values

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

<ipython-input-12-3a2c2cd7c02d> in forward(self, sent_id, mask)
     42         x = torch.cat(all_layers, 0) # torch.Size([13, 32, 64, 768])
     43         print(x.shape)
---> 44         x = self.conv(x)
     45         x = self.relu(x)
     46         x = self.pool(x)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in forward(self, input)
    421 
    422     def forward(self, input: Tensor) -> Tensor:
--> 423         return self._conv_forward(input, self.weight)
    424 
    425 class Conv3d(_ConvNd):

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
    418                             _pair(0), self.dilation, self.groups)
    419         return F.conv2d(input, weight, self.bias, self.stride,
--> 420                         self.padding, self.dilation, self.groups)
    421 
    422     def forward(self, input: Tensor) -> Tensor:

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [13, 13, 3, 768], but got 3-dimensional input of size [416, 64, 768] instead
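
If I read the message correctly, Conv2d expects a 4-D [N, C, H, W] input, while my torch.cat along dim 0 merges the 13 layers into the batch dimension and leaves a 3-D tensor. A minimal repro of just the shape mismatch:

import torch

all_layers = [torch.randn(32, 64, 768) for _ in range(13)]
print(torch.cat(all_layers, 0).shape)    # torch.Size([416, 64, 768]) -- 3-D
print(torch.stack(all_layers, 1).shape)  # torch.Size([32, 13, 64, 768]) -- 4-D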

I tried different values for the convolution arguments, but I still got a similar error, and sometimes an error saying that the max pooling output size is too small:

Given input size: (64x62x1). Calculated output size: (64x31x0). Output size is too small
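
If I understand the output-size formula floor((L - kernel) / stride) + 1, that happens when the pooling kernel is larger than what is left of a dimension. A hypothetical repro that matches the numbers in the error (not my exact code):

import torch
import torch.nn as nn

# After a (3, 768) convolution the last dimension is already 1, so a
# kernel of 2 there computes floor((1 - 2) / 2) + 1 = 0:
x = torch.randn(1, 64, 62, 1)   # reported as input size (64x62x1)
nn.MaxPool2d(kernel_size=2)(x)  # RuntimeError: Output size is too small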

I would be grateful for any help on how to set up this CNN layer correctly.