I’m converting some homegrown Keras attention code to PyTorch. In principle I can just use the MultiheadAttention module. The Keras code explicitly defines the weight matrices Q, K, and V. The torch module has member attributes q_proj_weight, k_proj_weight, etc., but these are initialized to None, and if I iterate through named_parameters() I only see in_proj_weight and out_proj.weight. What is the relationship between these matrices and Q/K/V?
I think it’s an optimization, as can be seen in MultiheadAttention.__init__:
```python
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

self.num_heads = num_heads
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

if not self._qkv_same_embed_dim:
    self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
    self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
    self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
    self.register_parameter('in_proj_weight', None)
else:
    self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
    self.register_parameter('q_proj_weight', None)
    self.register_parameter('k_proj_weight', None)
    self.register_parameter('v_proj_weight', None)
```
If kdim and vdim are equal to embed_dim (the default), then instead of three separate matrices a single in_proj_weight of shape (3 * embed_dim, embed_dim) is allocated: the Q, K, and V projection weights are stacked along the first dimension, in that order. That is why only in_proj_weight shows up in named_parameters() in the common case, while q_proj_weight, k_proj_weight, and v_proj_weight are registered as None.
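If it helps while porting, here is a minimal sketch (the shapes are arbitrary, and W_q/W_k/W_v are stand-ins for your Keras matrices) showing which parameters exist in each case and how to load explicit Q/K/V weights into the fused parameter:

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2

# Default kdim/vdim equal embed_dim, so only the fused in_proj_weight exists
mha = nn.MultiheadAttention(embed_dim, num_heads)
print(mha.in_proj_weight.shape)  # torch.Size([24, 8]), i.e. (3 * embed_dim, embed_dim)
print(mha.q_proj_weight)         # None

# Stand-ins for the Keras weights; if yours are stored in Keras's usual
# (input_dim, units) layout, transpose to torch's (out_features, in_features) first
W_q = torch.randn(embed_dim, embed_dim)
W_k = torch.randn(embed_dim, embed_dim)
W_v = torch.randn(embed_dim, embed_dim)

# in_proj_weight is the Q, K, V projections stacked row-wise, in that order
with torch.no_grad():
    mha.in_proj_weight.copy_(torch.cat([W_q, W_k, W_v], dim=0))

# With differing kdim/vdim, the separate per-projection matrices appear instead
mha2 = nn.MultiheadAttention(embed_dim, num_heads, kdim=4, vdim=6)
print(mha2.in_proj_weight)       # None
print(mha2.k_proj_weight.shape)  # torch.Size([8, 4])
```

The biases are fused the same way: in_proj_bias has shape (3 * embed_dim,), again in Q/K/V order.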