Hi,
I’m interested in PyTorch’s operator fusion (still relatively new, so please feel free to correct me). I have been reading the Scheduler’s source code and noticed a strange case where it seems to consider the fusability of a pair of nodes node1, node2 in the normal order, self.can_fuse(node1, node2), but also in the reverse order, self.can_fuse(node2, node1). Specifically, I’m referring to this code here:
    def check_all_pairs(nodes):
        for node1_index, node1 in enumerate(nodes):
            for node2 in nodes[node1_index + 1 :]:
                key = (node1, node2)
                if key in seen:
                    continue
                seen.add(key)
                if self.can_fuse(node1, node2):
                    possible_fusions.append(key)
                elif (node2.is_template() or node2.is_foreach()) and self.can_fuse(
                    node2, node1
                ):
                    # foreach fusions and epilogue fusions are order dependent
                    possible_fusions.append((node2, node1))
I inspected this function on some trivial functions via torch.compile(), and it seems that nodes is typically a set containing a graph node n and all the nodes it outputs to, m_1, ..., m_k (it’s probably not that simple, so feel free to correct me there). So we have a set (n, m_1, ..., m_k). So in the self.can_fuse(node1, node2) case we have:
- Check whether we can fuse n with some m_{i+1} (vertical fusion?)
- Check whether we can fuse m_i with m_{i+1} (horizontal fusion?)

for some i in 1...k.
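To check that I'm reading the loop correctly, here is a toy sketch with plain strings standing in for scheduler nodes. The names toy_can_fuse and toy_is_order_dependent, and the rule that a "template" node must come first in a pair, are made up purely for illustration; they are not Inductor's real checks.

```python
def check_all_pairs(nodes, can_fuse, is_order_dependent):
    """Toy stand-in for the scheduler loop; returns ordered candidate pairs."""
    seen = set()
    possible_fusions = []
    for node1_index, node1 in enumerate(nodes):
        for node2 in nodes[node1_index + 1:]:
            key = (node1, node2)
            if key in seen:
                continue
            seen.add(key)
            if can_fuse(node1, node2):
                # Forward order: node1 first, node2 second.
                possible_fusions.append(key)
            elif is_order_dependent(node2) and can_fuse(node2, node1):
                # Reverse order is only tried for order-dependent nodes.
                possible_fusions.append((node2, node1))
    return possible_fusions


# Hypothetical asymmetric rule: a "template" node may only appear first.
def toy_can_fuse(first, second):
    return not second.startswith("template")


def toy_is_order_dependent(node):
    return node.startswith("template")


nodes = ["n", "m_1", "template_m_2"]
pairs = check_all_pairs(nodes, toy_can_fuse, toy_is_order_dependent)
print(pairs)
```

With this made-up rule, the pair (n, m_1) is accepted in forward order, while both pairs involving template_m_2 fail the forward check and are only recorded through the reversed branch.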
Now my question is, when is self.can_fuse(node2, node1) meaningful? From my understanding, this could be wrong, as fusing operations in reverse order can affect the overall function’s correctness.
Would appreciate someone’s insight here. Many thanks!