Why is torch.jit.script slower?

In a nutshell,

  1. “compilation” analyzes whole functions, with knowledge of variable types - some optimizations are done at this level (e.g. dead code elimination)
  2. the Python bytecode interpreter is not used to execute the generated code - a more specialized executor for statically typed code is supposedly faster
  3. fusion optimizations further compile specialized CUDA kernels, so e.g. a.mul(b).add(c) is computed in one go (see the sketch after this list)
  4. some patterns have specialized optimizations, e.g. conv + batchnorm fusion
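Here is a minimal sketch of point 3 (my own illustration, not from the thread): the function name `fma` and the tensor shapes are arbitrary. Printing `.code` and `.graph` shows the statically typed TorchScript IR that the specialized executor runs instead of Python bytecode; the actual fusion of the mul+add chain into a single CUDA kernel generally only happens on GPU, after a few warm-up calls let the profiling executor specialize the graph.

```python
import torch

# Names and shapes are illustrative, not from the original post.
@torch.jit.script
def fma(a, b, c):
    # pointwise chain that the fuser can collapse into a single kernel
    return a.mul(b).add(c)

a, b, c = (torch.randn(1024, 1024) for _ in range(3))

# warm-up calls let the profiling executor specialize and optimize the graph
for _ in range(3):
    fma(a, b, c)

print(fma.code)   # the statically typed TorchScript source
print(fma.graph)  # the IR graph the executor runs instead of Python bytecode
```

For point 4, the conv + batchnorm fusion is typically applied to a scripted module in eval() mode via torch.jit.freeze or torch.jit.optimize_for_inference (assuming a reasonably recent PyTorch version).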