Self-attention implementation results are 'a bit' surprising

I am studying attention, so I have been experimenting with implementing it. I was curious about the importance of each part (residual connections, layer normalization, dropout, multiple heads) during training, so I tried to implement them myself and run some tests.

I saw that using attention without any residual connections was worse than using no attention at all. After adding the residual connections, the performance got closer to the base NoAttentionAttention version, but it was still a bit worse. Only after adding layer normalization did I start to get better results.

My question is: is this behaviour normal? It seems that using a simpler feed-forward layer instead of attention with residual connections is a bit better and much faster. I have three explanations in mind, but I don't know which one is correct (maybe a couple of them are correct to some degree, or all of them are wrong):

1 - Layer normalization is a crucial part of the architecture, and attention doesn't quite work without it.
2 - My implementation of the attention mechanism is wrong, or I've overlooked some details.
3 - My setup is not well suited to showing the benefits of transformers. I am training a character-level model, and my training data is the tiny_shakespeare dataset from Andrej Karpathy (I'm actually not sure whether he created it or just used it and made it popular, sorry if my attribution is wrong); a rough sketch of the data setup is below.
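Just for context, this is roughly what I mean by the character-level setup (the filename and the 90/10 train/test split here are assumptions for illustration, not my exact preprocessing code):

import torch

# character-level tokenization of tiny_shakespeare (assumed local copy of the text file)
with open("tiny_shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
data = torch.tensor([stoi[ch] for ch in text], dtype=torch.long)

context_len = 128                 # same context length as in the tests below
n = int(0.9 * len(data))          # assumed 90/10 train/test split
train_data, test_data = data[:n], data[n:]

def get_batch(split_data, batch_size=64):
    # sample random (input, next-character target) windows of length context_len
    ix = torch.randint(len(split_data) - context_len - 1, (batch_size,))
    x = torch.stack([split_data[i:i + context_len] for i in ix])
    y = torch.stack([split_data[i + 1:i + context_len + 1] for i in ix])
    return x, y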

My results are as follows:

NoAttentionAttention
min_train_loss: 1.701506521344185
min_test_loss: 1.5964061438009656
max_train_accuracy: 49.58238525390625
max_test_accuracy: 53.14371043238146
duration: 174.00348663330078

SimpleAttention
min_train_loss: 2.2374376397132876
min_test_loss: 2.024268022076837
max_train_accuracy: 40.8646728515625
max_test_accuracy: 46.13544332570043
duration: 1405.215410232544

ResidualAttention
min_train_loss: 1.894147829771042
min_test_loss: 1.66307471086239
max_train_accuracy: 46.54041748046875
max_test_accuracy: 52.319441170528016
duration: 1434.7716104984283

LayNormResAttention
min_train_loss: 0.36205335560441015
min_test_loss: 1.1429171343301905
max_train_accuracy: 89.10374755859375
max_test_accuracy: 67.51815269733298
duration: 1451.627137184143

These are my self-attention implementations:

import torch
import torch.nn as nn

class NoAttentionAttention(nn.Module):
    # Baseline block: a single linear layer + ReLU, no attention at all.
    def __init__(self, emb_size, head_size):
        super(NoAttentionAttention, self).__init__()
        self.linear = nn.Linear(emb_size, head_size)

    def forward(self, x):
        return nn.functional.relu(self.linear(x))

class SimpleAttention(nn.Module):
    # Single-head causal self-attention followed by a linear + ReLU, no residual.
    def __init__(self, emb_size, head_size):
        super(SimpleAttention, self).__init__()
        self.emb_size = emb_size
        self.Q = nn.Linear(emb_size, head_size)
        self.K = nn.Linear(emb_size, head_size)
        self.V = nn.Linear(emb_size, head_size)
        self.linear = nn.Linear(head_size, head_size)

    def forward(self, x):
        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)
        # scaled dot-product scores (emb_size == head_size in all my tests)
        att_w = (q @ k.transpose(1, 2)) * self.emb_size ** (-1 / 2)
        # causal mask: each position attends only to itself and earlier positions
        seq_len = x.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        att_w = att_w.masked_fill(mask, float('-inf'))
        att_s = nn.functional.softmax(att_w, dim=-1)
        att_out = att_s @ v
        return nn.functional.relu(self.linear(att_out))

class ResidualAttention(nn.Module):
    # Same causal attention as SimpleAttention, plus a residual connection around it.
    def __init__(self, emb_size, head_size):
        super(ResidualAttention, self).__init__()
        self.emb_size = emb_size
        self.Q = nn.Linear(emb_size, head_size)
        self.K = nn.Linear(emb_size, head_size)
        self.V = nn.Linear(emb_size, head_size)
        self.linear = nn.Linear(head_size, head_size)

    def forward(self, x):
        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)
        att_w = (q @ k.transpose(1, 2)) * self.emb_size ** (-1 / 2)
        seq_len = x.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        att_w = att_w.masked_fill(mask, float('-inf'))
        att_s = nn.functional.softmax(att_w, dim=-1)
        att_out = att_s @ v
        x = x + att_out  # residual connection (requires head_size == emb_size)
        return torch.nn.functional.relu(self.linear(x))

class LayNormResAttention(nn.Module):
    # ResidualAttention plus a LayerNorm applied after the residual sum (post-norm).
    def __init__(self, emb_size, head_size):
        super(LayNormResAttention, self).__init__()
        self.emb_size = emb_size
        self.Q = nn.Linear(emb_size, head_size)
        self.K = nn.Linear(emb_size, head_size)
        self.V = nn.Linear(emb_size, head_size)
        self.linear = nn.Linear(head_size, head_size)
        self.layer_norm = nn.LayerNorm(emb_size)

    def forward(self, x):
        q = self.Q(x)
        k = self.K(x)
        v = self.V(x)
        att_w = (q @ k.transpose(1, 2)) * self.emb_size ** (-1 / 2)
        seq_len = x.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        att_w = att_w.masked_fill(mask, float('-inf'))
        att_s = nn.functional.softmax(att_w, dim=-1)
        att_out = att_s @ v
        x = self.layer_norm(x + att_out)  # residual + post-LayerNorm
        return torch.nn.functional.relu(self.linear(x))
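
To help rule out explanation 2, here is a quick standalone sanity check (separate from the training runs above) that compares the causal attention math used in these classes against PyTorch's built-in torch.nn.functional.scaled_dot_product_attention. The shapes are arbitrary, and since emb_size == head_size in my tests the scaling factor is the same in both:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, D = 2, 8, 16                      # batch, sequence length, head size
q, k, v = torch.randn(3, B, T, D).unbind(0)

# manual causal attention, same math as in the classes above
att_w = (q @ k.transpose(1, 2)) * D ** (-1 / 2)
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
att_w = att_w.masked_fill(mask, float('-inf'))
manual = F.softmax(att_w, dim=-1) @ v

# PyTorch reference implementation with the same causal masking and 1/sqrt(D) scaling
reference = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, reference, atol=1e-5))  # expect True if the math matches

If this prints True, the core attention computation itself matches the reference, which would point more toward explanations 1 or 3.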

Another surprising result I've come across with the same setup is when I use multiple heads.

With this implementation:

class TorchMhAttention(nn.Module):
    # Same block structure, but using PyTorch's nn.MultiheadAttention.
    def __init__(self, emb_size, head_size, dropout=0.1):
        super(TorchMhAttention, self).__init__()
        self.emb_size = emb_size
        self.head_size = head_size
        self.multihead_attn = nn.MultiheadAttention(emb_size, num_heads=4, dropout=dropout)
        self.layer_norm = nn.LayerNorm(emb_size)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(emb_size, emb_size)

    def forward(self, x):
        # nn.MultiheadAttention defaults to batch_first=False, so switch to (seq, batch, emb)
        x = x.transpose(0, 1)
        seq_len = x.size(0)
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
        att_out, _ = self.multihead_attn(x, x, x, attn_mask=mask)
        x = self.layer_norm(x + att_out)
        x = torch.nn.functional.relu(self.linear(x))
        x = self.dropout(x)
        return x.transpose(0, 1)  # back to (batch, seq, emb)

With num_heads=1 I get these scores:

min_train_loss: 1.2817998005747795
min_test_loss: 1.1874091575371808
max_train_accuracy: 61.50433349609376
max_test_accuracy: 64.9322299299569
duration: 1759.3112213611603

but when I increase num_heads to 4 I get this result:

min_train_loss: 2.2550643470287324
min_test_loss: 2.189865370043393
max_train_accuracy: 41.768908691406246
max_test_accuracy: 45.89278648639548
duration: 2077.635309934616

For all of these tests, emb_size == head_size == 256, I stack 6 layers of the attention implementation being tested, and my context length is 128 characters.
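
One thing that may be relevant to the num_heads comparison: as far as I understand, nn.MultiheadAttention splits embed_dim across the heads rather than giving each head the full embedding, so with emb_size=256 and num_heads=4 each head attends in a 64-dimensional subspace. A quick check:

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=256, num_heads=4)
print(mha.head_dim)            # 64, i.e. 256 // 4 dimensions per head

x = torch.randn(128, 2, 256)   # (seq_len, batch, emb), since batch_first=False by default
out, _ = mha(x, x, x)
print(out.shape)               # torch.Size([128, 2, 256])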