In a transformer decoder, the cross-attention layer uses keys and values from the encoder and queries from the decoder. Like the other sub-layers, it is wrapped in a residual connection that implements out = x + F(x).
As implemented in the PyTorch source code, and as the original transformer diagram shows, the residual layer skip connection comes from the queries (arrow coming out of decoder self-attention), not from the keys and values (arrow coming from the encoder). That is, out = queries + F(queries, keys, values) is implemented.
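To make the expression concrete, here is a minimal NumPy sketch of out = queries + F(queries, keys, values). It is not the PyTorch source: it assumes a single head and leaves out the learned projections, layer normalization, and dropout. The names cross_attention, decoder_state, and encoder_memory are made up for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row-wise max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries, keys, values):
    # queries: (T_dec, d); keys and values: (T_enc, d)
    d = queries.shape[-1]
    scores = queries @ keys.T / np.sqrt(d)  # (T_dec, T_enc)
    weights = softmax(scores, axis=-1)      # each row sums to 1
    return weights @ values                 # (T_dec, d)

rng = np.random.default_rng(0)
T_dec, T_enc, d = 3, 5, 4                     # illustrative sizes
decoder_state = rng.normal(size=(T_dec, d))   # stands in for the decoder self-attention output
encoder_memory = rng.normal(size=(T_enc, d))  # stands in for the encoder output

# out = queries + F(queries, keys, values)
attn_out = cross_attention(decoder_state, encoder_memory, encoder_memory)
out = decoder_state + attn_out
print(out.shape)  # (3, 4)
```

The result has one row per decoder position, the same shape as the queries that feed the skip connection.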
The values indicate “How each input is represented”, and the point of the attention layer is to create a weighted average of the available values, so it seems like out = values + F(queries, keys, values) would make more sense.
Why is this not implemented?
This answer is from ChatGPT:
The design choice of using queries for the skip connection in the residual cross-attention block of a transformer decoder, rather than the values, is rooted in the fundamental working principles of transformers and the role of each component (queries, keys, and values) in the attention mechanism. Let’s break this down:
Understanding the Roles of Queries, Keys, and Values:
- Queries: Represent the current state of the decoder. In the context of sequence-to-sequence models, this could be the partially decoded sequence at a given step.
- Keys and Values: Derived from the encoder, they represent the information of the input sequence. Keys are used to compute attention weights, and values are used to construct the output based on these weights.
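For reference, the roles above correspond to the standard scaled dot-product attention formula. The symbols T_dec and T_enc (decoder and encoder sequence lengths) are notation introduced here, not taken from the answer:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V,
\qquad
Q \in \mathbb{R}^{T_{dec} \times d_k},\;
K \in \mathbb{R}^{T_{enc} \times d_k},\;
V \in \mathbb{R}^{T_{enc} \times d_v}
$$

The softmax term has shape T_dec x T_enc, so the result has shape T_dec x d_v: one weighted average of the values per query.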
Why Queries for Skip Connections?
- Maintaining Sequential Context: The primary role of the decoder is to generate the output sequence step-by-step. The queries represent the current ‘thought’ or state of the decoder at each step. The skip connection from queries ensures that this current state is always directly influencing the output, maintaining the sequential context inherent in the task.
- Residual Learning: In the original transformer architecture, the principle of residual learning (out = x + F(x)) is applied. Here, x is the input to a layer, and F(x) is the transformation applied by that layer. By using the queries (x) in the residual connection, the architecture ensures that each layer can learn to modify the current state incrementally. This is in line with the idea of residual networks where each layer ideally learns a residual modification to the identity function.
- Consistency in the Decoder: Throughout the decoder, each sub-layer (self-attention and cross-attention) uses its input (queries) for the residual connection. This consistency is important for the model’s learning process, as it allows each sub-layer to refine the decoder’s current state.
- Values as Contextual Information: The role of values is to provide contextual information from the encoder, used to generate the weighted average output based on the queries. This output is then added to the queries through the residual connection. If the values were used for the residual connection, it would mean directly adding raw encoder information to the output without the decoder’s current context, which could lead to a loss of sequential coherence in the decoder’s output.
- Attention as a Refinement Process: The attention mechanism can be viewed as a way to refine the decoder’s current state (queries) by considering additional information (keys and values from the encoder). The skip connection supports this view by ensuring that the refinement is an additive process to the current state.
- Model Stability and Training Efficiency: The query-based residual connection is the arrangement used in the original transformer, and it trains stably in practice. This might be due to the way the transformer balances learning from the encoder’s context and the decoder’s own generated context.
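One practical consequence of the points above can be checked with shapes alone. The attention output has one row per query, so it always lines up with the queries, while the values have one row per encoder position. The sketch below uses made-up sizes and plain NumPy with no learned projections; it is an illustration, not how PyTorch implements the layer.

```python
import numpy as np

rng = np.random.default_rng(0)
T_dec, T_enc, d = 3, 5, 4               # illustrative sizes with T_dec != T_enc
queries = rng.normal(size=(T_dec, d))   # from the decoder
keys = rng.normal(size=(T_enc, d))      # from the encoder
values = rng.normal(size=(T_enc, d))    # from the encoder

# Scaled dot-product attention, single head.
scores = queries @ keys.T / np.sqrt(d)              # (T_dec, T_enc)
scores = scores - scores.max(axis=-1, keepdims=True)
weights = np.exp(scores)
weights = weights / weights.sum(axis=-1, keepdims=True)
attn_out = weights @ values                         # (T_dec, d)

# The query-based residual is always shape-compatible.
out = queries + attn_out

# The value-based residual is not: (T_enc, d) + (T_dec, d) cannot broadcast here.
try:
    values + attn_out
    values_residual_works = True
except ValueError:
    values_residual_works = False

print(attn_out.shape, values.shape, values_residual_works)
```

With these sizes, out = values + F(queries, keys, values) is not even well defined, because the source and target sequences have different lengths. It would only type-check in the special case where the two lengths happen to be equal.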
In summary, using queries for skip connections in the transformer decoder’s cross-attention block is a design choice that aligns with the transformer’s overall architecture and its goal of sequential generation. It ensures the coherence and consistency of the decoder’s output while maintaining a stable and efficient learning process.