I tried to plug the Linformer implementation from [kuixu/Linear-Multihead-Attention](https://github.com/kuixu/Linear-Multihead-Attention) (which reproduces the linear multi-head attention introduced in the Linformer paper, "Linformer: Self-Attention with Linear Complexity") into PyTorch, replacing the `MultiheadAttention` class and forward function in `transformer.py` with `linear_multihead_attention`. Using `seq_len = 332` and `project_k = 128`, I printed the shape of `attn_output_weights`: `torch.Size([64, 332, 128])`, and of `attn_mask`: `torch.Size([1, 332, 332])`.
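
To make the problem concrete, here is a minimal sketch of why those two shapes clash (the tensors below are random stand-ins with the printed shapes, not the repo's actual code): Linformer projects the keys/values from `seq_len` down to the projected dimension, so the attention weights end up as `(batch * num_heads, seq_len, project_k)`, while a standard attention mask is still `(seq_len, seq_len)`, and the usual additive masking step no longer broadcasts.

```python
import torch

# Stand-in shapes matching the printout above (assumed: 64 = batch_size *
# num_heads, 332 = seq_len, 128 = project_k):
attn_output_weights = torch.randn(64, 332, 128)  # (bsz * heads, seq_len, project_k)
attn_mask = torch.zeros(1, 332, 332)             # (1, seq_len, seq_len)

# nn.MultiheadAttention adds the mask to the weights before the softmax;
# with Linformer's projected key dimension this addition fails:
try:
    attn_output_weights += attn_mask
except RuntimeError as e:
    print(e)  # size mismatch in the last dimension (128 vs. 332)
```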