Hello!

I’m trying to implement Grouped Query Attention for a Vision Transformer, but I cannot get the checkpoint conversion to work. The GQA paper states that the key and value tensors are mean-pooled along the head axis, and, more importantly, that **performance right after conversion is already decent** (little to no uptraining is required to get the model performing close to the MHA equivalent).
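
To check my reading of that pooling step, here is a tiny self-contained sketch of what I understand "mean pooling along the head axis" to mean at the tensor level (the shapes are just an example, not from my actual model):

```python
import torch

# Example shapes only: batch, MHA heads, patches, head dim
B, H, P, d = 2, 8, 197, 64
H_kv = 4                       # target number of KV heads
group = H // H_kv              # query heads per KV head

k = torch.randn(B, H, P, d)    # per-head keys from the MHA model
# Group consecutive heads and average each group into a single KV head
k_pooled = k.view(B, H_kv, group, P, d).mean(dim=2)
print(k_pooled.shape)          # torch.Size([2, 4, 197, 64])
```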

I have tried to get this working with a Vision Transformer, but right now the GQA variant reaches at most ~50% accuracy after the first epoch on my classification dataset, while the MHA baseline gets closer to 90%. So unless I’m misreading the paper, I must be doing something wrong.

Here is the code so far:

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class GQA(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        num_kv_heads: Optional[int] = None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        # Default: at least two query heads in each group
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else (num_heads // 2)
        # Queries keep the full width; keys/values are reduced to num_kv_heads
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=qkv_bias)
        self.v = nn.Linear(dim, self.num_kv_heads * self.head_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, P, C = x.shape
        H = self.num_heads
        q = self.q(x).view(B, P, H, -1).transpose(1, 2)                  # (B, H, P, head_dim)
        k = self.k(x).view(B, P, self.num_kv_heads, -1).transpose(1, 2)  # (B, num_kv_heads, P, head_dim)
        v = self.v(x).view(B, P, self.num_kv_heads, -1).transpose(1, 2)  # (B, num_kv_heads, P, head_dim)
        q = q * self.scale

        # Split the query heads into groups, one group per KV head
        group_size = self.num_heads // self.num_kv_heads
        q_grps = torch.split(q, group_size, dim=1)
        k_grps = torch.split(k, 1, dim=1)
        v_grps = torch.split(v, 1, dim=1)

        outputs = [None] * len(k_grps)
        for i in range(len(k_grps)):
            # Each group of query heads attends to a single shared KV head
            # (note q has a larger head axis)
            curr_q = q_grps[i]  # (B, group_size, num_patches, head_dim)
            curr_k = k_grps[i]  # (B, 1, num_patches, head_dim)
            curr_v = v_grps[i]  # (B, 1, num_patches, head_dim)
            scores = curr_q @ curr_k.transpose(-2, -1)
            weights = F.softmax(scores, dim=-1)  # (B, group_size, num_patches, num_patches)
            weights = self.attn_drop(weights)
            curr_att = weights @ curr_v          # (B, group_size, num_patches, head_dim)
            outputs[i] = curr_att

        x = torch.cat(outputs, dim=1)                     # (B, num_heads, num_patches, head_dim)
        x = x.transpose(1, 2).contiguous().view(B, P, C)  # (B, num_patches, emb_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def att_weight_conversion(self, qkv_params, is_bias=False):
        '''
        Split and convert the fused QKV parameters from a ViT checkpoint
        into parameters for this GQA implementation.
        '''
        q, k, v = torch.split(qkv_params, qkv_params.shape[0] // 3, dim=0)
        group_size = self.num_heads // self.num_kv_heads

        def convert_weight(param):
            x = param.clone()
            # TODO: check whether to bring the heads axis to the front or the middle
            x = x.view(self.dim, self.num_heads, self.dim // self.num_heads)
            xs = torch.split(x, group_size, dim=1)            # split across head axis
            xs = [xs[i].mean(dim=1) for i in range(len(xs))]  # mean-pool each group
            x = torch.cat(xs, dim=1)
            expected_shape = (self.dim, self.num_kv_heads * self.dim // self.num_heads)
            assert x.shape == expected_shape, f'Expected {expected_shape}, got {x.shape}'
            return x

        def convert_bias(param):
            x = param.clone()
            x = x.view(self.num_heads, self.dim // self.num_heads)
            xs = torch.split(x, group_size, dim=0)            # split across head axis
            xs = [xs[i].mean(dim=0) for i in range(len(xs))]  # mean-pool each group
            x = torch.cat(xs, dim=0)
            expected_shape = (self.num_kv_heads * self.dim // self.num_heads,)
            assert x.shape == expected_shape, f'Expected {expected_shape}, got {x.shape}'
            return x

        return {
            "q": q,
            "k": convert_weight(k) if not is_bias else convert_bias(k),
            "v": convert_weight(v) if not is_bias else convert_bias(v),
        }

    def load_pretrained_weights(self, state_dict, block_idx):
        # Load and convert the fused QKV parameters for this block
        qkv_weight = state_dict[f'blocks.{block_idx}.attn.qkv.weight']
        qkv_bias = state_dict[f'blocks.{block_idx}.attn.qkv.bias']
        wdict = self.att_weight_conversion(qkv_weight)
        bdict = self.att_weight_conversion(qkv_bias, is_bias=True)
        # assign_check is a small helper (defined elsewhere in my code) that
        # verifies shapes match before assigning the new parameter
        self.q.weight = assign_check(self.q.weight, wdict['q'])
        self.q.bias = assign_check(self.q.bias, bdict['q'])
        self.k.weight = assign_check(self.k.weight, wdict['k'].T)
        self.k.bias = assign_check(self.k.bias, bdict['k'])
        self.v.weight = assign_check(self.v.weight, wdict['v'].T)
        self.v.bias = assign_check(self.v.bias, bdict['v'])
        # Load the parameters for the output projection as-is
        self.proj.weight = assign_check(self.proj.weight, state_dict[f'blocks.{block_idx}.attn.proj.weight'])
        self.proj.bias = assign_check(self.proj.bias, state_dict[f'blocks.{block_idx}.attn.proj.bias'])
```
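
For context, this is roughly how I drive the conversion. My checkpoint uses the timm-style fused `qkv` keys you can see above, so the sketch below assumes a timm ViT-B/16 and is simplified from my actual setup:

```python
import timm

# Assumption: a timm ViT-B/16 checkpoint (dim=768, 12 heads, 12 blocks)
vit = timm.create_model('vit_base_patch16_224', pretrained=True)
state_dict = vit.state_dict()

# One GQA module per transformer block, halving the number of KV heads
gqa_attns = [GQA(dim=768, num_heads=12, qkv_bias=True, num_kv_heads=6) for _ in range(12)]
for idx, attn in enumerate(gqa_attns):
    attn.load_pretrained_weights(state_dict, block_idx=idx)
```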

Please ignore the bulk of the forward pass unless there’s a glaring issue with it.

Hoping someone can help shed some light on what could be the issue.

Thank you!