JIT graph doesn't show fused kernels?

For example, a simple JIT-scripted MSE:

import torch

@torch.jit.script
def jit_mse(input, target):
    return ((input - target) ** 2).mean()

shows the fused CUDA kernel fused_sub_pow in the profiler, yet its graph property is unchanged and still contains the unfused aten::sub and aten::pow ops.
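Roughly how I'm observing this (a minimal sketch; the kernel name and whether fusion kicks in depend on the PyTorch version, executor, and device):

inp = torch.randn(1024, 1024, device="cuda")
tgt = torch.randn(1024, 1024, device="cuda")

# warm-up runs so the profiling executor can specialize and fuse
for _ in range(3):
    jit_mse(inp, tgt)

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    jit_mse(inp, tgt)

# the kernel table lists the fused kernel (e.g. fused_sub_pow)...
print(prof.key_averages().table(sort_by="cuda_time_total"))

# ...but the graph still shows the unfused aten::sub and aten::pow
print(jit_mse.graph)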
This seems to make the graph property a bit useless, or is this not its intended usage?
Cheers

The graph is the unoptimized representation; see graph_for(input, target) for the “full deal”. So the graph property answers “how did TorchScript understand my program?”, while graph_for answers “what will TorchScript run?”.
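Concretely (a minimal sketch, continuing from your snippet above; graph_for takes example inputs because the optimized graph is input-dependent):

print(jit_mse.graph)                # what TorchScript understood: aten::sub, aten::pow, aten::mean
print(jit_mse.graph_for(inp, tgt))  # what it will run for these inputs, fusion groups included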

I’ve tried to summarize a bit of how that happens in a couple of articles on my blog: Optimizing functions in the JIT and Runtime overview. Note that none of that is official information, so treat anything on my blog as “random person on the internet says…”.

Best regards

Thomas


Thank you for the quick reply @tom and also for the great articles!
It looks like torch.jit.last_executed_optimized_graph() is what I have been looking for.
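In case it helps others, a minimal sketch of how I'm calling it (the function has to have been executed at least once so that a “last executed” graph exists):

jit_mse(inp, tgt)  # run once so an optimized graph has been recorded
print(torch.jit.last_executed_optimized_graph())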
Thanks

Last I looked, graph_for was using that, too…
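Roughly along these lines (a sketch of the idea only, not the actual implementation):

def graph_for_sketch(fn, *args):
    # hypothetical helper: execute, then read back the optimized graph
    fn(*args)
    return torch.jit.last_executed_optimized_graph()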