I am looking to use a model that was implemented with flash attention (flash-attn < 1.0.5) under the hood.
I would like to port its weights to my own implementation, which simply uses torch.nn.MultiheadAttention.
Now my question:
Am I guaranteed that flash attention's weights
transformer_encoder.layers.1.self_attn.Wqkv.weight
are compatible with the corresponding MultiheadAttention weights,
transformer_encoder.layers.1.attention.in_proj_weight ?
The tensors have the same shape. What I'm essentially asking is: are the implementations equivalent in the sense that the Wq, Wk and Wv matrices are stacked in the same order within these tensors?
While biases are not used in this instance (attention bias was False by default), I'd also be interested to know whether I can simply port the attention biases (again, same shape) without worrying.
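For what it's worth, on the PyTorch side the layout can be verified empirically rather than trusted: in_proj_weight is the [Wq; Wk; Wv] stack along dim 0, which the sketch below confirms by splitting it, running attention by hand, and comparing against the module's output. Whether flash-attn's Wqkv uses the same q/k/v ordering for your installed version is an assumption that would need the analogous check on the flash-attn side; the dimensions and layout here are only illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads, seq, batch = 8, 2, 4, 3
head_dim = embed_dim // num_heads

# bias=False mirrors the question: no in_proj bias, no out_proj bias
mha = nn.MultiheadAttention(embed_dim, num_heads, bias=False)
x = torch.randn(seq, batch, embed_dim)  # (seq, batch, dim), batch_first=False

# Hypothesis: in_proj_weight is Wq, Wk, Wv concatenated along dim 0
Wq, Wk, Wv = mha.in_proj_weight.chunk(3, dim=0)
q, k, v = x @ Wq.T, x @ Wk.T, x @ Wv.T

def to_heads(t):
    # (seq, batch, dim) -> (batch * heads, seq, head_dim), as PyTorch does internally
    return t.reshape(seq, batch * num_heads, head_dim).transpose(0, 1)

q, k, v = map(to_heads, (q, k, v))
attn = torch.softmax(q @ k.transpose(-2, -1) / head_dim**0.5, dim=-1)
out = (attn @ v).transpose(0, 1).reshape(seq, batch, embed_dim)
out = out @ mha.out_proj.weight.T  # out_proj.bias is None with bias=False

ref, _ = mha(x, x, x)
print(torch.allclose(out, ref, atol=1e-5))  # → True: ordering is q, k, v
```

The same trick applied to a flash-attn MHA layer (manual projection with the three chunks of Wqkv.weight, compared to the layer's forward pass) would settle the question for the biases as well, since in_proj_bias follows the same stacking as in_proj_weight.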