Hi everyone,
I’m currently trying to learn by implementing a BERT-like network myself. I’m using it to perform feature extraction and classification on the MNIST dataset. I believe the task is simple and therefore suitable for learning. A transformer encoder may not be the best fit for image feature extraction, but it should be powerful enough for me to draw a conclusion.
The details of my implementation can be briefly outlined as follows:
- Use the MNIST dataset provided by `torchvision`. With batch_size set to 128, the dataloader returns 28×28 single-channel images of shape [128, 1, 28, 28].
- Cut each image into 16 equal patches (4 along each direction), so that an image can be interpreted as a sequence of length 16 with (28/4)**2 = 49 dimensions per token. → [128, 16, 49]
- Project to the hidden dimension with a dense layer. → [128, 16, 64] (a quick shape check for these two steps follows this list)
- Repeat 6 transformer encoder layers, each structurally very similar to a BERT layer.
- Output the final classification result with an MLP. → [128, 10]
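For concreteness, here is a small shape check I put together for the patching and projection steps (my own throwaway snippet with dummy data, using the same `einops.rearrange` call as the full code at the end):

import torch
import torch.nn as nn
from einops import rearrange

x = torch.randn(128, 1, 28, 28)             # dummy batch shaped like the MNIST dataloader output
p = 28 // 4                                 # patch size = 7
patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
print(patches.shape)                        # torch.Size([128, 16, 49])
tokens = nn.Linear(7 * 7 * 1, 64)(patches)  # dense projection to the hidden dimension
print(tokens.shape)                         # torch.Size([128, 16, 64])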
I will include my full implementation code at the end of this post. I consulted a lot of online code while writing it, but there are many details I can’t understand, and I would really appreciate it if anyone could answer any of the following:
- The sample code starts with the following transformation of the MNIST dataset. I wonder why this transformation is applied; is it because of some requirement from performing positional embedding on the original data’s distribution? (I added a quick mean/std check right after the snippet.)
transform_mnist = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
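As far as I can tell, 0.1307 and 0.3081 are simply the mean and standard deviation of the MNIST training pixels, so this just standardizes the inputs; here is the quick check I ran (my own snippet, not part of the sample code):

import torchvision

raw_set = torchvision.datasets.MNIST('./data', download=True, train=True,
                                     transform=torchvision.transforms.ToTensor())
pixels = raw_set.data.float() / 255.0              # uint8 images scaled to [0, 1], like ToTensor does
print(pixels.mean().item(), pixels.std().item())   # prints roughly 0.1307 and 0.3081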
- The sample code provides a `cls_token` parameter that changes the data from [128, 16, 64] into [128, 17, 64]. I don’t quite understand what this is used for.
- Should a dropout layer be added inside multi-head self-attention? I have looked at a lot of code on the Internet, but it is implemented in different ways: some code adds dropout right after the attention computation, some adds it after the final output of the self-attention block, and some doesn’t add it at all (I sketched the two placements I mean just below). I would like to know how dropout is used in the BERT structures (and their various modifications) that are common in production environments today.
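To make the dropout question concrete, here is my rough paraphrase of the two placements I keep running into (a hypothetical single-head sketch, not taken from any particular library):

import torch.nn as nn

attn_dropout = nn.Dropout(0.1)   # placement 1: dropout on the softmaxed attention weights
proj_dropout = nn.Dropout(0.1)   # placement 2: dropout after the output projection

def attention_with_dropout(q, k, v, scale, W_o):
    # q, k, v: [batch, seq_len, d]; W_o: nn.Linear(d, d); single head for clarity
    attn = (q @ k.transpose(-2, -1)) / scale
    attn = attn_dropout(attn.softmax(dim=-1))    # placement 1 applied here
    out = attn @ v
    return proj_dropout(W_o(out))                # placement 2 applied here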
- What should be the recommended order for LayerNorm, dropout, and the residual connection after attention? I noticed that some code puts LayerNorm in front of the self-attention block, which is different from most other code I’ve seen and confuses me (I paraphrased both orderings just below).
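For reference, here is my paraphrase of the two orderings (a minimal sketch, not taken from any specific codebase; the first is roughly what my EncoderBlock below does, the second is the variant that puts LayerNorm first):

def post_norm_block(x, attn, ffn, norm1, norm2, dropout):
    # sublayer -> dropout -> residual add -> LayerNorm (the ordering my code uses)
    x = norm1(x + dropout(attn(x)))
    x = norm2(x + dropout(ffn(x)))
    return x

def pre_norm_block(x, attn, ffn, norm1, norm2, dropout):
    # LayerNorm first, residual add after the sublayer (the ordering that confused me)
    x = x + dropout(attn(norm1(x)))
    x = x + dropout(ffn(norm2(x)))
    return x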
- I noticed that the positional encoding is only applied once in the sample code, before the first encoder block, instead of each encoder block performing its own positional encoding. Does this mean that subsequent blocks lack the ability to capture position? (The two options I’m comparing are sketched just below.)
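Just to clarify what I mean by each block performing its own positional encoding (a hypothetical sketch; option A is what the sample code does, option B is what I am asking about):

import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Identity() for _ in range(6)])   # stand-ins for the encoder blocks
pos_embedding = nn.Parameter(torch.randn(1, 17, 64))
x = torch.randn(128, 17, 64)

# option A: add the positional embedding once, before the whole stack
out_a = x + pos_embedding
for block in blocks:
    out_a = block(out_a)

# option B: re-add the positional embedding before every block
out_b = x
for block in blocks:
    out_b = block(out_b + pos_embedding)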
- I noticed that the input to the final MLP, which should theoretically be the features extracted by the previous layers, is `x[:, 0, :]`, and I wonder why. Is it because we’re doing a classification task? If our task were instead regressing to a continuous range of values, should we also use `x[:, 0, :]` as the MLP’s input? (The two pooling options I’m weighing are sketched just below.)
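To illustrate this last question, here are the two ways I’m considering feeding the encoder output to the head (my own sketch; the mean-pooling option is just an alternative I’ve seen elsewhere, not something the sample code uses):

import torch

x = torch.randn(128, 17, 64)              # encoder output: [batch, cls token + 16 patches, hidden]

cls_features = x[:, 0, :]                 # option 1: take only the cls-token position -> [128, 64]
mean_features = x[:, 1:, :].mean(dim=1)   # option 2: average the 16 patch tokens      -> [128, 64]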
Thank you all. Here is my full training code. Unfortunately, although it runs, there seems to be some problem with the output, and the loss keeps piling up; I don’t know what caused it.
import torch
import numpy as np
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from einops import rearrange
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
download_path = './data'
batch_size = 128
class MultiHeadSelfAttention(nn.Module):
def __init__(self, hidden_dim, heads, dropout_rate, bias=False):
super().__init__()
self.hidden_dim = hidden_dim
self.heads = heads
self.d_k = hidden_dim // heads
self.scale = np.sqrt(self.d_k)
self.W_q, self.W_k, self.W_v, self.W_o = [nn.Linear(hidden_dim, hidden_dim, bias=bias) for _ in range(4)]
self.dropout = nn.Dropout(dropout_rate)
def forward(self, x, mask=None):
batch_size, seq_len, hidden_dim = x.shape
# q, k, v -> [batch_size, 8, 17, 8]
q = self.W_q(x).reshape(batch_size, seq_len, self.heads, self.d_k).permute(0, 2, 1, 3)
k = self.W_k(x).reshape(batch_size, seq_len, self.heads, self.d_k).permute(0, 2, 1, 3)
v = self.W_v(x).reshape(batch_size, seq_len, self.heads, self.d_k).permute(0, 2, 1, 3)
# attention score
attn = torch.matmul(q, k.transpose(2, 3)) / self.scale # -> [batch_size, 8, 17, 17]
if mask is not None:
... # masking
attn = attn.softmax(dim=-1)
        v = torch.matmul(attn, v).permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.hidden_dim)  # -> [batch_size, seq_len, hidden_dim], here [128, 17, 64]
return self.dropout(self.W_o(v)) # -> [batch_size, 17, 64]
class PositionWiseFeedForward(nn.Module):
def __init__(self, hidden_dim, mlp_dim, dropout_rate):
super().__init__()
self.hidden_dim = hidden_dim
self.mlp_dim = mlp_dim
self.dense1 = nn.Linear(hidden_dim, mlp_dim)
self.relu = nn.ReLU()
self.dense2 = nn.Linear(mlp_dim, hidden_dim)
self.dropout = nn.Dropout(dropout_rate)
def forward(self, x):
return self.dropout(self.dense2(self.relu(self.dense1(x))))
class EncoderBlock(nn.Module):
def __init__(self, hidden_dim, heads, mlp_dim, dropout_rate):
super().__init__()
self.hidden_dim = hidden_dim
self.heads = heads
self.self_attention_layer = MultiHeadSelfAttention(hidden_dim, heads, dropout_rate)
self.ffn = PositionWiseFeedForward(hidden_dim, mlp_dim, dropout_rate)
self.norm_layer1 = nn.LayerNorm(hidden_dim)
self.norm_layer2 = nn.LayerNorm(hidden_dim)
def forward(self, x):
        y = self.self_attention_layer(x)            # attention sublayer (dropout is applied inside)
        z = self.norm_layer1(x + y)                 # residual add, then LayerNorm
        return self.norm_layer2(z + self.ffn(z))    # second sublayer gets its own LayerNorm
class MnistTransformerClsNet(nn.Module):
def __init__(self, image_size, patch_div_num, class_num, depth, heads, hidden_dim, mlp_dim, channels=1):
super().__init__()
assert image_size % patch_div_num == 0, 'image dimensions must be divisible by the patch size'
self.image_size = image_size
self.patch_div_num = patch_div_num
self.patch_size = image_size // patch_div_num
self.convert_layer = nn.Linear(channels * self.patch_size ** 2, hidden_dim, bias=True)
self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))
self.pos_embedding = nn.Parameter(torch.randn(1, patch_div_num**2 + 1, hidden_dim))
self.blocks = nn.Sequential()
for _ in range(depth):
self.blocks.add_module(f'block_{_}', EncoderBlock(hidden_dim, heads, mlp_dim, 0.1))
self.mlp_layer = nn.Sequential(
nn.Linear(hidden_dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, class_num)
)
def forward(self, x): # input: x -> [batch_size, 1, 28, 28]
p = self.patch_size
x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p) # -> [batch_size, 16, 49]
x = self.convert_layer(x) # -> [batch_size, 16, 64]
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) # -> [batch_size, 1, 64]
x = torch.cat((cls_tokens, x), dim=1) # -> [batch_size, 17, 64]
x += self.pos_embedding # -> [batch_size, 17, 64]
for _, block in enumerate(self.blocks):
x = block(x) # -> [batch_size, 17, 64]
return self.mlp_layer(x[:, 0]) # -> [batch_size, 64] -> [batch_size, 10]
def train():
transform_mnist = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
train_set = torchvision.datasets.MNIST(download_path, download=True, train=True, transform=transform_mnist)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
net = MnistTransformerClsNet(image_size=28, patch_div_num=4, class_num=10, depth=6, heads=8, hidden_dim=64, mlp_dim=128).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=0.002)
for epoch in range(100):
# train
for step, (x, y) in enumerate(train_loader):
x, y = x.to(device), y.to(device)
            output = net(x)                    # net and x are already on device, so no extra .to() is needed
            optimizer.zero_grad()
            loss = F.cross_entropy(output, y)  # the network returns raw logits, so cross_entropy fits (nll_loss expects log-probabilities)
loss.backward()
optimizer.step()
if step % 100 == 0:
print('[' + '{:5}'.format(step * len(x)) + '/' + '{:5}'.format(len(train_loader.dataset)) +
' (' + '{:3.0f}'.format(100 * step / len(train_loader)) + '%)] Loss: ' +
'{:6.4f}'.format(loss.item()))
if __name__ == '__main__':
train()
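For completeness, here is a small evaluation sketch I plan to bolt on later (my own addition, not part of the sample code; it assumes the globals and MnistTransformerClsNet defined above):

def evaluate(net):
    # test-set accuracy check, reusing the same normalization constants as train()
    transform_mnist = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])
    test_set = torchvision.datasets.MNIST(download_path, download=True, train=False, transform=transform_mnist)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    net.eval()
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            correct += (net(x).argmax(dim=1) == y).sum().item()
    net.train()
    return correct / len(test_loader.dataset)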