In DINOv2 (`dinov2/attention.py` in the facebookresearch/dinov2 repo on GitHub), the basic attention goes like this:
```python
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q @ k.transpose(-2, -1)
```
The FlashAttention path goes like this:
```python
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = xformers.ops.unbind(qkv, 2)
x = xformers.ops.memory_efficient_attention(
    q, k, v,
    attn_bias=attn_bias,
    op=xformers.ops.fmha.MemoryEfficientAttentionFlashAttentionOp if attn_bias is not None else None,
)
```
Because of the permutes / non-contiguous views (notably, the heads dimension has to move next to the batch dimension for the batched matmul), does this force any copies in the matmul and in FlashAttention (especially on the FlashAttention path)?
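For concreteness, here is a minimal check of what the FlashAttention-path split does before the kernel runs (shapes are made up; `qkv` stands in for `self.qkv(x)`). The `unbind` itself only returns strided views with a stride-1 last dimension, so any copy would have to be forced by the kernel's layout requirements, not by the split:

```python
import torch

# Assumed shapes for illustration; qkv stands in for self.qkv(x).
B, N, H, D = 2, 197, 12, 64
C = H * D
qkv = torch.randn(B, N, 3 * C)

# xformers-style split: reshape is a view, unbind(2) returns strided views.
q, k, v = qkv.reshape(B, N, 3, H, D).unbind(2)
print(q.shape, q.stride())              # (B, N, H, D), last-dim stride is 1
print(q.is_contiguous())                # False: q/k/v are interleaved in memory
print(q.data_ptr() == qkv.data_ptr())   # True: still a view, no copy so far
```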
If so, any hints on alleviating this copy? E.g., is it worth replacing the fused qkv transform with a bmm, so that the unbind can go along the 0th dimension and each of q, k, v comes out contiguous?
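For concreteness, here is a sketch of the bmm variant I have in mind, under the same assumed shapes. `w` and `b` are hypothetical stand-ins for `self.qkv.weight` / `self.qkv.bias`, and `baddbmm` is one way to put the "3" on dim 0 so that `unbind(0)` yields contiguous chunks:

```python
import torch

B, N, H, D = 2, 197, 12, 64
C = H * D
x = torch.randn(B, N, C)

# Hypothetical stand-ins for self.qkv.weight (3C, C) and self.qkv.bias (3C,).
w = torch.randn(3 * C, C).view(3, C, C)
b = torch.randn(3 * C).view(3, 1, C)

# One batched matmul with the "3" on dim 0. The expand is stride-0 on dim 0,
# so x is not copied; the (3, B*N, C) output is contiguous.
qkv = torch.baddbmm(b, x.reshape(1, B * N, C).expand(3, -1, -1), w.transpose(1, 2))

# Unbinding along dim 0 now yields three contiguous tensors, and the
# reshapes to (B, N, H, D) below are pure views.
q, k, v = qkv.unbind(0)
q, k, v = (t.view(B, N, H, D) for t in (q, k, v))
print(q.is_contiguous(), k.is_contiguous(), v.is_contiguous())  # True True True
```

Whether this actually wins would depend on the strided-batch GEMM being as fast as the single fused GEMM of `self.qkv`, so I'd profile both.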