Why does the skip connection in a transformer decoder's residual cross attention block come from the queries rather than the values?

A transformer decoder's cross-attention layer uses keys and values from the encoder, and queries from the decoder. These residual layers implement out = x + F(x).

As implemented in the PyTorch source code, and as shown in the original transformer diagram, the skip connection comes from the queries (the arrow coming out of the decoder self-attention), not from the keys and values (the arrow coming from the encoder). That is, the implementation computes out = queries + F(queries, keys, values).
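To make the question concrete, here is a minimal sketch (my own simplified module, not the actual PyTorch source) of a decoder cross-attention sub-layer in which the residual connection is taken from the queries:

```python
import torch
import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    """Simplified decoder cross-attention sub-layer (post-norm style)."""

    def __init__(self, d_model=64, nhead=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, queries, memory):
        # Keys and values both come from the encoder output ("memory");
        # queries come from the decoder stream.
        attn_out, _ = self.attn(queries, memory, memory)
        # Skip connection from the queries: out = queries + F(q, k, v)
        return self.norm(queries + attn_out)

block = CrossAttentionBlock()
q = torch.randn(2, 5, 64)    # decoder stream: (batch, tgt_len, d_model)
mem = torch.randn(2, 7, 64)  # encoder output: (batch, src_len, d_model)
out = block(q, mem)
print(out.shape)  # torch.Size([2, 5, 64])
```

Note that the output necessarily has the queries' sequence length, not the encoder's, which is why the addition `queries + attn_out` is shape-compatible while `memory + attn_out` in general would not be.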

The values represent "how each input is represented", and the point of the attention layer is to create a weighted average of the available values, so it seems like out = values + F(queries, keys, values) would make more sense.

Why is this not implemented?
