Problems encountered while implementing a BERT-like model

Hi everyone,

I’m currently trying to learn by implementing a BERT-like network myself. I’m using it for CV feature extraction → classification on the MNIST dataset. I believe the task is simple and therefore suitable for learning. A transformer encoder may not be the best choice for image feature extraction, but it should be powerful enough for me to draw conclusions.

The details of my implementation can be briefly outlined as follows:

  1. Using the MNIST dataset provided by torchvision; with batch_size set to 128, the dataloader returns 28*28 single-channel images of shape [128, 1, 28, 28].
  2. Cut each image into 16 equal parts (4 equal parts along each direction), so that one image can be interpreted as a sequence of length 16 with (28/4)**2 = 49 dimensions per patch (see the shape-check sketch after this list). → [128, 16, 49]
  3. Project to the hidden dimension with a dense layer. → [128, 16, 64]
  4. Pass the sequence through 6 transformer encoder layers, each structurally very similar to a BERT layer.
  5. Output the final classification result with an MLP. → [128, 10]
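
To sanity-check steps 2 and 3, here is a minimal sketch of the patching and projection shapes (not part of the training script; the variable names are just for illustration):

import torch
import torch.nn as nn
from einops import rearrange

batch = torch.zeros(128, 1, 28, 28)   # dummy MNIST batch
patch_size = 28 // 4                  # 4 equal parts per direction -> 7x7 patches
# cut each image into a 4x4 grid of patches and flatten every patch
patches = rearrange(batch, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
print(patches.shape)                  # torch.Size([128, 16, 49])
# project each 49-dim patch to the hidden dimension
hidden = nn.Linear(patch_size ** 2, 64)(patches)
print(hidden.shape)                   # torch.Size([128, 16, 64])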

I will include the full implementation code at the end of this post. I consulted a lot of online code while writing it, but there are many details I don’t understand. I would really appreciate it if anyone could answer any of the following:

  1. The sample code starts with the following transformation of the MNIST dataset. I wonder why this transformation is applied. Is it because positional embedding places some requirement on the distribution of the original data?
transform_mnist = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
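For reference, my understanding is that 0.1307 and 0.3081 are simply the mean and standard deviation of the MNIST training pixels after ToTensor scales them to [0, 1]; a small sketch to reproduce them (assuming the dataset is already in ./data):

import torch
import torchvision

train_set = torchvision.datasets.MNIST('./data', train=True, download=True,
                                       transform=torchvision.transforms.ToTensor())
# stack all training images and compute global pixel statistics
pixels = torch.stack([img for img, _ in train_set])
print(pixels.mean().item(), pixels.std().item())  # ~0.1307, ~0.3081

I still don’t see how this interacts with the positional embedding, though.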
  2. The sample code provides a cls_token layer that changes the data from [128, 16, 64] into [128, 17, 64]. I don’t quite understand what this is used for.
  3. Should a dropout layer be added inside multi-head self-attention? The code I found online is implemented in different ways: some adds dropout right after the attention weights, some adds it after the final output of the self-attention block, and some doesn’t add it at all. I would like to know where dropout is used in the BERT structure (and its common variants) in production environments today.
  4. What is the recommended order for LayerNorm, dropout, and the residual connection after attention? I noticed that some code puts LayerNorm in front of the self-attention block, which differs from most other implementations and confuses me (see the sketch after this list).
  5. I noticed that the positional encoding is applied only once in the sample code (instead of each encoder block applying its own positional encoding). Does this mean that subsequent blocks lack the ability to capture position?
  6. I noticed that the input to the final MLP, which should theoretically be the features extracted by the preceding network, is x[:, 0, :], and I wonder why. Is it because we are doing a classification task? If the task were instead regression to a continuous value, should we still use x[:, 0, :] as the MLP’s input?
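
To make question 4 concrete, here is a minimal sketch (simplified, dropout omitted, module names are my own) of the two orderings I keep running into, the post-norm ordering my EncoderBlock below uses versus the pre-norm variant:

import torch.nn as nn

class PostNormBlock(nn.Module):
    # residual first, then LayerNorm (the ordering my EncoderBlock uses)
    def __init__(self, dim, attn, ffn):
        super().__init__()
        self.attn, self.ffn = attn, ffn
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x):
        x = self.norm1(x + self.attn(x))
        return self.norm2(x + self.ffn(x))

class PreNormBlock(nn.Module):
    # LayerNorm applied before each sublayer, residual around the sublayer
    def __init__(self, dim, attn, ffn):
        super().__init__()
        self.attn, self.ffn = attn, ffn
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        return x + self.ffn(self.norm2(x))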

Thank you all. Here is my full training code. Unfortunately, although it runs, something seems wrong with the output, and the loss keeps piling up; I don’t know what causes it.

import torch
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from einops import rearrange

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
download_path = './data'
batch_size = 128

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, hidden_dim, heads, dropout_rate, bias=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.heads = heads
        self.d_k = hidden_dim // heads
        self.scale = np.sqrt(self.d_k)
        self.W_q, self.W_k, self.W_v, self.W_o = [nn.Linear(hidden_dim, hidden_dim, bias=bias) for _ in range(4)]
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None):
        batch_size, seq_len, hidden_dim = x.shape
        # q, k, v -> [batch_size, 8, 17, 8]
        q = self.W_q(x).reshape(batch_size, seq_len, self.heads, self.d_k).permute(0, 2, 1, 3)
        k = self.W_k(x).reshape(batch_size, seq_len, self.heads, self.d_k).permute(0, 2, 1, 3)
        v = self.W_v(x).reshape(batch_size, seq_len, self.heads, self.d_k).permute(0, 2, 1, 3)

        # attention score
        attn = torch.matmul(q, k.transpose(2, 3)) / self.scale # -> [batch_size, 8, 17, 17]
        if mask is not None:
            ... # masking
        attn = attn.softmax(dim=-1)
        v = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.hidden_dim) # matmul -> [batch_size, 8, 17, 8], after reshape -> [batch_size, 17, 64]
        return self.dropout(self.W_o(v)) # -> [batch_size, 17, 64]

class PositionWiseFeedForward(nn.Module):
    def __init__(self, hidden_dim, mlp_dim, dropout_rate):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.dense1 = nn.Linear(hidden_dim, mlp_dim)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(mlp_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        return self.dropout(self.dense2(self.relu(self.dense1(x))))

class EncoderBlock(nn.Module):
    def __init__(self, hidden_dim, heads, mlp_dim, dropout_rate):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.heads = heads
        self.self_attention_layer = MultiHeadSelfAttention(hidden_dim, heads, dropout_rate)
        self.ffn = PositionWiseFeedForward(hidden_dim, mlp_dim, dropout_rate)
        self.norm_layer1 = nn.LayerNorm(hidden_dim)
        self.norm_layer2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        y = self.self_attention_layer(x)
        z = self.norm_layer1(x + y)
        return self.norm_layer2(z + self.ffn(z))

class MnistTransformerClsNet(nn.Module):
    def __init__(self, image_size, patch_div_num, class_num, depth, heads, hidden_dim, mlp_dim, channels=1):
        super().__init__()
        assert image_size % patch_div_num == 0, 'image dimensions must be divisible by the patch size'
        self.image_size = image_size
        self.patch_div_num = patch_div_num
        self.patch_size = image_size // patch_div_num

        self.convert_layer = nn.Linear(channels * self.patch_size ** 2, hidden_dim, bias=True)
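        # learnable classification token prepended to the patch sequence (cf. question 2)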
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
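        # learnable positional embedding for the 16 patch tokens plus the cls token (cf. question 5)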
        self.pos_embedding = nn.Parameter(torch.randn(1, patch_div_num**2 + 1, hidden_dim))
        self.blocks = nn.Sequential()
        for _ in range(depth):
            self.blocks.add_module(f'block_{_}', EncoderBlock(hidden_dim, heads, mlp_dim, 0.1))
        self.mlp_layer = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, class_num)
        )

    def forward(self, x): # input: x -> [batch_size, 1, 28, 28]
        p = self.patch_size
        x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p) # -> [batch_size, 16, 49]
        x = self.convert_layer(x) # -> [batch_size, 16, 64]

        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # -> [batch_size, 1, 64]
        x = torch.cat((cls_tokens, x), dim=1) # -> [batch_size, 17, 64]
        x += self.pos_embedding # -> [batch_size, 17, 64]
        for block in self.blocks:
            x = block(x) # -> [batch_size, 17, 64]
        return self.mlp_layer(x[:, 0]) # -> [batch_size, 64] -> [batch_size, 10]

def train():
    transform_mnist = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_set = torchvision.datasets.MNIST(download_path, download=True, train=True, transform=transform_mnist)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    net = MnistTransformerClsNet(image_size=28, patch_div_num=4, class_num=10, depth=6, heads=8, hidden_dim=64, mlp_dim=128).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=0.002)

    for epoch in range(100):
        # train
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            output = net(x).to(device)
            optimizer.zero_grad()
            loss = F.nll_loss(output, y)
            loss.backward()
            optimizer.step()
            if step % 100 == 0:
                print('[' + '{:5}'.format(step * len(x)) + '/' + '{:5}'.format(len(train_loader.dataset)) +
                      ' (' + '{:3.0f}'.format(100 * step / len(train_loader)) + '%)]  Loss: ' +
                      '{:6.4f}'.format(loss.item()))

if __name__ == '__main__':
    train()