Point Op Fusion default

Hi,

I have come across numerous discussions saying that the PyTorch JIT can fuse pointwise ops like layernorm, softmax, etc. and provide higher performance. Yet most layers that do pointwise ops do not seem to use the JIT fusion support by default. For example, when I call nn.LayerNorm, I do not see any fusion happening.

What is the reason for this, and how can we use the JIT support on existing layers without rebuilding PyTorch?

You’re misunderstanding: those are “vector” ops. Elementwise ops are “scalar” ones, with one-to-one input-to-output element dependencies. With this trivial dependency structure, a fusion implementation is relatively simple and generic. And what is being talked about is inter-op fusion, so it applies to two or more operations chained together (e.g. log(a*b+c)), not to nn.LayerNorm by itself.
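For illustration, here is a minimal sketch of triggering the JIT fuser on such a chain of elementwise ops. It assumes a CUDA device is available (the fuser historically targets GPU tensors by default), and the exact name of the fused node in the graph depends on the PyTorch version:

```python
import torch

# A chain of elementwise ops: each output element depends only on the
# corresponding input elements, so the JIT fuser can combine them into
# a single kernel.
@torch.jit.script
def fused_pointwise(a, b, c):
    return torch.log(a * b + c)

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
c = torch.randn(1024, 1024, device="cuda")

# Run a few times so the profiling executor can specialize and fuse.
for _ in range(3):
    out = fused_pointwise(a, b, c)

# Inspect the optimized graph; a fused node (e.g. prim::TensorExprGroup
# or prim::FusionGroup, depending on the version) indicates fusion happened.
print(torch.jit.last_executed_optimized_graph())
```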

PS: IIRC there are also some JIT optimizations for certain layer combinations; that’s another form of fusion too.
