What is the Big O runtime of a transformer model during inference? How about during training? Seems to me that it is still O(n), since it can only generate each word of the output one at a time.
It depends on whether you use autoregressive prediction or not.
If you do, it's O(n²); otherwise it's O(n). There are also other transformer variants with different complexities.
Transformers are supposed to do in parallel at least part of what RNN seq2seq models can’t. So what would you say the Big O is for an RNN seq2seq?
By the way, I thought an autoregressive transformer would be O(n), where n = length of the target sequence. Could you explain why it would be O(n2)?
Transformers are highly parallelizable because they are built from matrix multiplications and linear layers. Since a linear layer processes each position independently of the others, the whole sequence can be processed at once. RNNs, by contrast, require the output of the previous step, which creates a sequential bottleneck.
The transformer itself (setting aside the newer/optimized versions from research) is O(n), where n is the length of the sequence. It's as simple as: for each element in the input we need a single pass to obtain the corresponding output.
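To make the contrast concrete, here is a minimal numpy sketch (toy sizes, random weights, nothing from a real model): the RNN loop below has a data dependency across steps, while the transformer-style linear layer handles every position in one matrix multiplication.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 4                      # sequence length, hidden size
x = rng.standard_normal((n, d))  # toy input sequence

# RNN: each step needs the previous hidden state -> inherently sequential
W_h = rng.standard_normal((d, d)) * 0.1
W_x = rng.standard_normal((d, d)) * 0.1
h = np.zeros(d)
for t in range(n):               # n dependent steps; cannot be parallelized
    h = np.tanh(x[t] @ W_x + h @ W_h)

# Transformer-style linear layer: every position is processed independently,
# so the whole sequence is a single matrix multiplication
W = rng.standard_normal((d, d))
parallel_out = x @ W             # all n positions at once

# Processing positions one by one gives the same result, just slower
looped_out = np.stack([x[t] @ W for t in range(n)])
assert np.allclose(parallel_out, looped_out)
```

The assertion at the end is the whole point: the per-position loop and the single matmul compute the same thing, which is why the hardware is free to do it all in parallel.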
The autoregressive versions run iteratively. Take translation, for example (a seq2seq task where the length of the output is unknown in advance):
You process the (known) context and obtain features with an encoder.
Then you need to produce the translated version:
Run the decoder once with the start token to obtain the first prediction.
Feed the start token + first prediction –> 2nd prediction.
Feed start, 1st, 2nd, and so on.
In short, each time you run the transformer you are predicting the next (and only the next) element of the sequence, and each pass re-reads the whole prefix generated so far. Therefore, if your target sentence has length n, the decoder processes 1 + 2 + … + n = n(n+1)/2 positions in total –> ~O(n^2)
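The counting argument above can be sketched in a few lines. This is not a real decoder; `fake_decoder` is a made-up stand-in whose only job is to let us tally how many prefix positions get processed across the whole generation loop.

```python
def fake_decoder(prefix):
    """Stand-in for one transformer decoder pass; its cost scales with len(prefix)."""
    return len(prefix) % 100            # dummy "next token"

START, n = 0, 10                        # assume a target sequence of length n
tokens = [START]
positions_processed = 0
for _ in range(n):
    positions_processed += len(tokens)  # each pass re-reads the whole prefix
    tokens.append(fake_decoder(tokens))

# 1 + 2 + ... + n = n(n+1)/2 positions in total -> O(n^2) token-steps
assert positions_processed == n * (n + 1) // 2
```

Doubling n roughly quadruples `positions_processed`, which is the O(n²) behaviour being discussed.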
Thanks for the reply! What I’m confused about is - if the transformer only has an encoder, then everything is done in parallel, so no matter the size of the input, the runtime is constant? That sounds like O(1) to me
So I guess it depends on the base unit we consider here.
I'm taking as the unit the transformer processing a single element of the sequence.
In that case the “encoder” is O(n) and an autoregressive would be O(n^2).
If you consider one pass of the whole transformer as the unit (regardless of sequence length),
then the encoder is O(1) and the autoregressive version is O(n).
Note that in practice, the sequence length may be a bottleneck. If your GPU has enough capacity to run the whole sequence at once no matter its length, we are in the second case. If your GPU is not powerful enough, it makes more sense to take the former perspective.
EDIT: Another way of seeing it is https://www.reddit.com/r/LanguageTechnology/comments/9gulm9/complexity_of_transformer_attention_network/
Taking into account the multi-head self-attention complexity.
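For reference, the self-attention complexity the linked thread discusses comes from the n × n score matrix. A minimal numpy sketch of scaled dot-product attention (single head, random toy data) makes the quadratic term visible:

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 8, 16                          # sequence length, head dimension
Q = rng.standard_normal((n, d))
K = rng.standard_normal((n, d))
V = rng.standard_normal((n, d))

# Scaled dot-product attention: the score matrix has shape (n, n),
# so time and memory grow quadratically with the sequence length.
scores = Q @ K.T / np.sqrt(d)         # (n, n)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)  # softmax over each row
out = weights @ V                     # (n, d)

assert scores.shape == (n, n)
assert np.allclose(weights.sum(axis=-1), 1.0)
```

So even the fully parallel "encoder-only" pass is O(n²) in compute per layer; the parallelism changes wall-clock time, not the amount of work.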
Thanks for the detailed answer! I never thought about the GPU capability influencing the runtime or complexity, but it does make sense
Well, being strictly precise, it depends on the sequence length as stated above.
Taking into account that "all the ops" run in parallel on the GPU, I'd say the limitation depends on its capacity. Note that this isn't my area of expertise, so don't take my words as carved in stone!