Inputs for torch.nn.MultiheadAttention

Dear Community,

I am reviewing the API definition and documentation for “torch.nn.MultiheadAttention”. It appears the forward method needs query, key, and value. Does this mean query, key, and value need to be learned outside “torch.nn.MultiheadAttention”?

Assuming x is the input, does this mean the following is incorrect? (we are directly passing the same input x for query, key and value instead of obtaining these embeddings)

self_attent = torch.nn.MultiheadAttention(embed_dim=256, num_heads=8)
attn_output, attn_output_weights = self_attent(x, x, x)

Please advise.

Thanks!

It’ll do the projections internally; see parameters like q_proj_weight in the torch.nn.MultiheadAttention documentation (the module lives in torch.nn.modules.activation, PyTorch 2.1).
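As a minimal sketch of where those internal projections live: when query, key, and value all share embed_dim, the module fuses the three projections into a single in_proj_weight parameter; if you pass different kdim/vdim, it keeps q_proj_weight, k_proj_weight, and v_proj_weight separately.

```python
import torch

# Same embed_dim for q/k/v: projections are fused into one parameter
# of shape (3 * embed_dim, embed_dim).
mha = torch.nn.MultiheadAttention(embed_dim=256, num_heads=8)
print(mha.in_proj_weight.shape)   # torch.Size([768, 256])

# Different key/value dims: projections are stored separately.
mha2 = torch.nn.MultiheadAttention(embed_dim=256, num_heads=8,
                                   kdim=128, vdim=64)
print(mha2.q_proj_weight.shape)   # torch.Size([256, 256])
print(mha2.k_proj_weight.shape)   # torch.Size([256, 128])
print(mha2.v_proj_weight.shape)   # torch.Size([256, 64])
```

So the learned q/k/v projections are ordinary parameters of the module itself; you only supply the raw inputs.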

Code like self_attent(x, x, x) in your example is pretty typical, and q/k/v are projected from that input just fine. Passing distinct key/query/value tensors matters in cases like cross-attention, because the query’s sequence length doesn’t have to match that of the key/value. If you’re building a decoder-only or encoder-only model, it shouldn’t matter.

You can find more detail on how specific cases are implemented in attention.h.


Thank you. This is very helpful.