How does the Sparse Transformer reduce memory complexity?

I’m trying to implement the model called “Sparse Transformer” in PyTorch.

As far as I can tell from the fairseq open-source code for the sparse attention mechanism, they simply add a mask matrix to the original QK dot-product matrix of shape (trg_seq_len, src_seq_len).
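For reference, here is roughly what I understand that operation to look like. This is my own minimal sketch, not fairseq’s actual code; the function name, shapes, and the causal mask are just assumptions for illustration:

```python
import torch
import torch.nn.functional as F

def masked_attention(q, k, v, mask):
    # q: (trg_seq_len, d), k and v: (src_seq_len, d)
    # scores is a full (trg_seq_len, src_seq_len) matrix, so memory is still n**2
    scores = q @ k.transpose(0, 1) / (q.size(-1) ** 0.5)
    # mask holds 0 for kept positions and -inf for masked-out ones,
    # and is simply added to the scores before the softmax
    scores = scores + mask
    attn = F.softmax(scores, dim=-1)
    return attn @ v

trg_len, src_len, d = 8, 8, 16
q = torch.randn(trg_len, d)
k = torch.randn(src_len, d)
v = torch.randn(src_len, d)
# example: a causal mask (-inf above the diagonal, 0 elsewhere)
mask = torch.full((trg_len, src_len), float('-inf')).triu(1)
out = masked_attention(q, k, v, mask)
```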

I’m wondering how the operation above can reduce memory complexity when it still computes all n**2 dot products. Could you explain how this works?
