My rough understanding of fusion in torch.compile is that it compiles the following code
    def func(x):
        y = f1(x)
        z = f2(y)
        return z
into C++ code roughly like this:
    for (int64_t i = 0; i < n; i++) {  // n = number of elements in x
        auto tmp_scalar = f1(x[i]);
        z[i] = f2(tmp_scalar);
    }
So the intermediate y only ever lives in a register (tmp_scalar) and is never written to memory.
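To make the difference concrete, here is a small pure-Python model of the two loop structures (not what torch.compile actually emits). The bodies of f1 and f2 are made up for illustration; any elementwise scalar functions would do. The unfused version materializes the full intermediate y, while the fused version keeps each intermediate in a local variable, mirroring tmp_scalar in the C++ loop above.

```python
import math

# Hypothetical elementwise scalar functions, chosen only for illustration.
def f1(v: float) -> float:
    return math.sin(v) * 2.0

def f2(v: float) -> float:
    return v * v + 1.0

def unfused(x: list[float]) -> list[float]:
    # Two passes: the whole intermediate y is stored in memory.
    y = [f1(v) for v in x]
    return [f2(v) for v in y]

def fused(x: list[float]) -> list[float]:
    # One pass: each intermediate value lives only in a local ("register").
    z = [0.0] * len(x)
    for i in range(len(x)):
        tmp_scalar = f1(x[i])
        z[i] = f2(tmp_scalar)
    return z

x = [0.0, 0.5, 1.0, 2.0]
print(fused(x) == unfused(x))
```

Both compute the same values in the same order; only the memory traffic for y differs.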
However, if f1 and f2 are big functions, everything inside them is expanded (inlined) into the generated code. Is there a way to prevent f1 and f2 from being expanded while still keeping the fusion behavior, in order to reduce compilation time? In C++ terms, f1 and f2 would each be compiled into a non-inlined scalar function, and I would then do one more, separate compilation that just combines f1 and f2 and loops over i. If I understand correctly, a graph break does not fuse the parts on either side of it, so the full y tensor is written to memory, which is exactly what I want to avoid.
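As a rough model of the structure being asked for (this is not a torch.compile feature, and `fuse_scalar_stages` is a hypothetical helper), the sketch below keeps each stage as a separately built scalar callable that is called rather than inlined, and only the outer per-element loop is newly assembled. The intermediate value stays in a local variable, so no full-size y buffer is allocated.

```python
from typing import Callable, Sequence

ScalarFn = Callable[[float], float]

def fuse_scalar_stages(*stages: ScalarFn) -> Callable[[Sequence[float]], list[float]]:
    """Hypothetical helper: build one loop that applies every stage per element.

    Each stage remains an independent callable (built once, invoked, never
    inlined); the intermediate is held in `tmp`, not in a full-size buffer.
    """
    def loop(x: Sequence[float]) -> list[float]:
        z = [0.0] * len(x)
        for i, v in enumerate(x):
            tmp = v
            for stage in stages:
                tmp = stage(tmp)
            z[i] = tmp
        return z
    return loop

def f1(v: float) -> float:  # hypothetical "big" stage, built once
    return 3.0 * v + 1.0

def f2(v: float) -> float:  # hypothetical "big" stage, built once
    return v * v

func = fuse_scalar_stages(f1, f2)
print(func([0.0, 1.0, 2.0]))  # [1.0, 16.0, 49.0]
```

The trade-off this models: the per-element call overhead of non-inlined stages is accepted in exchange for compiling each big stage only once.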