Attention block without copies for contiguity

In DINOv2 (dinov2/attention.py at main · facebookresearch/dinov2 · GitHub), the basic attention path goes like this:

B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (3, B, num_heads, N, head_dim)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]  # each (B, num_heads, N, head_dim)

attn = q @ k.transpose(-2, -1)
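
For reference, here is a standalone repro of that layout (made-up, roughly ViT-S-sized shapes) with a contiguity check and a profile of the q @ k matmul; aten::contiguous / aten::copy_ entries in the trace would indicate a copy materialized at the ATen level (it wouldn't catch anything the backend kernels do internally):

import torch
from torch.profiler import ProfilerActivity, profile

B, N, C, num_heads = 2, 197, 384, 6              # made-up ViT-S-like shapes
scale = (C // num_heads) ** -0.5
x = torch.randn(B, N, C)
qkv_proj = torch.nn.Linear(C, 3 * C)

qkv = qkv_proj(x).reshape(B, N, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * scale, qkv[1], qkv[2]

# the permute leaves k and v as strided views of the fused projection output
# (q is a fresh contiguous tensor only because of the scaling multiply)
print(k.is_contiguous(), v.is_contiguous())      # False False

with profile(activities=[ProfilerActivity.CPU]) as prof:
    attn = q @ k.transpose(-2, -1)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))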

The FlashAttention path goes like this:

B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)  # (B, N, 3, num_heads, head_dim)
q, k, v = xformers.ops.unbind(qkv, 2)  # each (B, N, num_heads, head_dim)

x = xformers.ops.memory_efficient_attention(
    q, k, v,
    attn_bias=attn_bias,
    op=xformers.ops.fmha.MemoryEfficientAttentionFlashAttentionOp if attn_bias is not None else None,
)
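
The same kind of check for the layout that goes into memory_efficient_attention (xformers itself isn't needed just to inspect the strides):

import torch

B, N, C, num_heads = 2, 197, 384, 6
qkv = torch.randn(B, N, 3 * C).reshape(B, N, 3, num_heads, C // num_heads)
q, k, v = qkv.unbind(2)                          # the (B, N, num_heads, head_dim) layout memory_efficient_attention takes

# each of q, k, v is a strided view into the fused qkv buffer: the last dim is
# contiguous, but the sequence stride still spans all of q, k and v
print(q.shape)                                   # torch.Size([2, 197, 6, 64])
print(q.is_contiguous(), q.stride())             # False, with the fused-buffer strides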

Because of the permutes and the resulting non-contiguity (notably, the heads dimension has to move next to the batch dimension), does this force any copies in the matmul or inside FlashAttention? I'm especially interested in the FlashAttention path.

If so, are there any hints on avoiding these copies? For example, is it worth replacing the qkv transform with a bmm, so that the unbind can go along the 0th dimension? A rough sketch of what I have in mind follows.
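
Here, the fused weight and bias are stood in by random tensors rather than self.qkv.weight / self.qkv.bias; the point is only that the unbind happens along dim 0 and each of q, k, v comes out contiguous in the (B, N, num_heads, head_dim) layout:

import torch

B, N, C, num_heads = 2, 197, 384, 6
head_dim = C // num_heads
x = torch.randn(B, N, C)
w = torch.randn(3 * C, C)                        # stands in for self.qkv.weight
b = torch.randn(3 * C)                           # stands in for self.qkv.bias

# view the fused projection as three (C, C) projections batched over a leading dim of 3
w3 = w.reshape(3, C, C)
b3 = b.reshape(3, 1, C)
qkv = torch.baddbmm(b3, x.reshape(1, B * N, C).expand(3, -1, -1), w3.transpose(1, 2))

qkv = qkv.reshape(3, B, N, num_heads, head_dim)
q, k, v = qkv.unbind(0)                          # each (B, N, num_heads, head_dim)
print(q.is_contiguous(), k.is_contiguous(), v.is_contiguous())   # True True True

This trades the single fused addmm for a batched matmul, so even if it does remove a copy somewhere it may not be a net win, which is exactly why I'm asking.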