Multi-Headed Attention: Query & Key transformations acquire extremely low gradients

Hi,

I am trying to train a version of multi-headed attention on input batches of sequence length 10. Below is a simplified version of my code:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    # Multi-Head Attention module

    def __init__(self, n_head, d_model, embed_dim_per_head, dropout=0.1):
        super().__init__()

        total_embed_dim = embed_dim_per_head * n_head
        self.w_qs = nn.Linear(d_model, total_embed_dim, bias=False)
        self.w_ks = nn.Linear(d_model, total_embed_dim, bias=False)
        self.w_vs = nn.Linear(d_model, total_embed_dim, bias=False)
        self.fc = nn.Linear(total_embed_dim, d_model)

        self.attention = nn.MultiheadAttention(
                embed_dim = total_embed_dim,
                num_heads = n_head,
                dropout = dropout)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v):

        residual = q

        # project (NB, SEQ_LEN, D_MODEL) -> (NB, SEQ_LEN, N_HEAD * D), then permute to (SEQ_LEN, NB, N_HEAD * D)
        q = self.w_qs(q).permute(1, 0, 2)
        k = self.w_ks(k).permute(1, 0, 2)
        v = self.w_vs(v).permute(1, 0, 2)
        q, att_map = self.attention(q, k, v, need_weights=True)
        print(f'AttMap Max Values: {att_map[0].max(dim=-1)[0]}')
        # SEQ_LEN, NB, (N_HEAD * D) -> NB, SEQ_LEN, (N_HEAD * D)
        q = q.permute(1, 0, 2)
        q = self.dropout(self.fc(q))
        q += residual
        q = self.layer_norm(q)
        return q


class IntraModal_MHA(nn.Module):
    def __init__(self, num_blocks, num_heads, transform_dim, embed_dim_per_head, activation):
        super(IntraModal_MHA, self).__init__()
        self.activation = activation
        self.num_blocks = num_blocks
        self.mha_list = nn.ModuleList([
             MultiHeadAttention(num_heads, transform_dim, embed_dim_per_head)
             for i in range(num_blocks)])
        self.linear_list = nn.ModuleList([
            nn.Sequential(
                nn.Linear(transform_dim, 2 * transform_dim),
                self.activation,
                nn.Dropout(0.1),
                nn.Linear(2 * transform_dim, transform_dim),
                self.activation,
                nn.Dropout(0.1),
                )
            for i in range(num_blocks)])
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(transform_dim, eps=1e-6)
            for i in range(num_blocks)])

    def forward(self, modal_ft):
        for i in range(self.num_blocks):
            output = self.mha_list[i](modal_ft, modal_ft, modal_ft)
            residual = output
            output = self.linear_list[i](output)
            output += residual
            modal_ft = self.layer_norms[i](output)
        return modal_ft


class A2H(nn.Module):
    def __init__(self):
        super(A2H, self).__init__()
        self.activation = nn.ReLU(inplace=True)

        self.audio_transform = nn.Sequential(
                nn.Linear(128, 512),
                self.activation,
                nn.Dropout(0.1),
                )

        self.a2a_temporal_att = IntraModal_MHA(1, 1, 512, 512, self.activation)
        self.classifier_fc = nn.Linear(512, 29)

    @staticmethod
    def positional_encoding(n_position, emb_dim):
        # The sinusoid position encoding table
        position_enc = np.array([
            [pos / np.power(10000, 2 * (j // 2) / emb_dim) for j in range(emb_dim)]
            for pos in range(n_position)])
        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1
        return torch.from_numpy(position_enc).type(torch.FloatTensor)

    def forward(self, audio_ft, unused_):
        nb, seq_len, aud_ft_dim = audio_ft.shape
        audio_ft_pe = audio_ft + self.positional_encoding(seq_len, aud_ft_dim).cuda()
        audio_ft_pe = self.audio_transform(audio_ft_pe)
        a2a_temporal_out = self.a2a_temporal_att(audio_ft_pe)
        signal_out = self.classifier_fc(a2a_temporal_out)
        classifier_out = F.softmax(signal_out, dim=-1)
        return classifier_out
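
For reference, the model is fed batches of 128-dimensional audio features with sequence length 10. A minimal smoke test of the forward pass might look like the sketch below (the batch size of 4 is arbitrary, and a GPU is required because the forward call uses .cuda()):

# Hypothetical smoke test of the forward pass; batch size 4 is arbitrary.
model = A2H().cuda()
audio_ft = torch.randn(4, 10, 128).cuda()  # (NB, SEQ_LEN, AUD_FT_DIM)
out = model(audio_ft, None)                # second argument is unused
print(out.shape)                           # expected: torch.Size([4, 10, 29])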

My optimizer is the usual SGD-with-momentum setup, with a fairly high learning rate of 0.2.
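Roughly, the setup looks like this (a sketch only; the exact momentum value is not stated above, so 0.9 is used as a placeholder):

# Sketch of the described optimizer setup; momentum=0.9 is a placeholder value.
optimizer = torch.optim.SGD(model.parameters(), lr=0.2, momentum=0.9)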
Below are the outputs of my prints, which display:

  1. the gradients to the nn.Linear layers that transform my input into the Query and Key matrices, and
  2. the maximum values of the 1st batch's attention map returned by nn.MultiheadAttention (a sketch of the logging loop follows the output below).

After Update:
a2a_temporal_att.mha_list.0.w_qs.weight : 3.750015922787675e-12
a2a_temporal_att.mha_list.0.w_ks.weight : 3.718521150719578e-12
AttMap Max Values: tensor([0.1128, 0.1126, 0.1123, 0.1128, 0.1134, 0.1130, 0.1128, 0.1122, 0.1121,
0.1125], device='cuda:0', grad_fn=)
After Update:
a2a_temporal_att.mha_list.0.w_qs.weight : 4.729227426336635e-12
a2a_temporal_att.mha_list.0.w_ks.weight : 4.986864220180021e-12
AttMap Max Values: tensor([0.1131, 0.1140, 0.1128, 0.1137, 0.1142, 0.1136, 0.1132, 0.1129, 0.1123,
0.1122], device='cuda:0', grad_fn=)
After Update:
a2a_temporal_att.mha_list.0.w_qs.weight : 2.1090891186292815e-11
a2a_temporal_att.mha_list.0.w_ks.weight : 1.926272509555904e-11
AttMap Max Values: tensor([0.1123, 0.1121, 0.1126, 0.1123, 0.1123, 0.1122, 0.1118, 0.1124, 0.1122,
0.1123], device='cuda:0', grad_fn=)
After Update:
a2a_temporal_att.mha_list.0.w_qs.weight : 1.62918636914533e-11
a2a_temporal_att.mha_list.0.w_ks.weight : 1.50868172121843e-11
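
(The logging code itself is not shown above; it is a loop along these lines, sketched here with the mean absolute gradient as the reduction and run after loss.backward():)

# Sketch of the gradient logging; the reduction shown is the mean absolute gradient.
print('After Update:')
for name, param in model.named_parameters():
    if ('w_qs' in name or 'w_ks' in name) and param.grad is not None:
        print(f'{name} : {param.grad.abs().mean().item()}')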

As you can see, my implementation of IntraModal_MHA contains residual connections to avoid the vanishing-gradient problem and resembles the standard Transformer block structure.

Below is the gradient accumulation for the entire network in the 6th epoch, during which my accuracy increased from 28% to 33%:

[image: per-layer gradient magnitudes for the 6th epoch]

The gradients to both w_qs and w_ks are always extremely low, and the attention map is not converging to a sharp distribution. The input sequences contain foreground- and background-type samples, and the objective is to classify the foreground samples, so I would expect the attention map to become sharp (a rough way to quantify this is sketched below). Additionally, as I increase the complexity of the network by adding more blocks to IntraModal_MHA, the gradients become even smaller.
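
One way to quantify the "sharpness" (not part of my training code, just for illustration) is the entropy of each attention row. A uniform distribution over 10 positions has entropy ln(10) ≈ 2.30, and the max values above (all ≈ 0.11 ≈ 1/10) suggest the rows are essentially uniform:

import math

# Sketch: entropy of each query's attention distribution, computed from the
# (NB, SEQ_LEN, SEQ_LEN) att_map returned by nn.MultiheadAttention.
# Lower entropy = sharper attention; a uniform row over SEQ_LEN=10 gives ln(10) ≈ 2.30.
entropy = -(att_map * (att_map + 1e-12).log()).sum(dim=-1)
print(f'Mean attention entropy: {entropy.mean().item():.3f} '
      f'(uniform would be {math.log(att_map.size(-1)):.3f})')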

Has anyone faced this, or can someone please point out any obvious mistake I might be making?

Thanks a lot!

For future readers: you may want to check out this article: Tutorial #17: Transformers III Training - Borealis AI.

TL;DR: Try using the Adam optimizer.

Transformers also differ from convolutional networks in that stochastic gradient descent does not work well for training (figure 2) and adaptive optimizers like Adam are required. Liu et al., 2020 observed that differentiating through the self-attention mechanism creates unbalanced gradients. In particular, the gradients for the query Φq and key Φk parameters were much smaller than those for the value parameters Φv, and so the former parameters change much more slowly. This is a direct consequence of the mathematical expression for self-attention. The Adam optimizer fixes this problem by essentially having different learning rates for each parameter. To conclude, we’ve seen that residual connections are needed to allow us to train deep networks. These cause gradient explosion, which is resolved by using layer normalization. The self-attention computation causes unbalanced gradients, which necessitates the use of Adam (figure 4). In the next section, we’ll see that layer normalization and Adam themselves cause more problems, which ultimately result in the need for learning rate warm-up.
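
A minimal sketch of what this looks like in PyTorch (the learning rate, betas, and warm-up length are common Transformer defaults, not values taken from the article):

# Sketch: switch from SGD to Adam; the hyperparameters below are common Transformer
# defaults (from "Attention Is All You Need"), not prescribed by the article.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)

# Optional linear learning-rate warm-up, as discussed later in the article.
warmup_steps = 4000
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps))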
