Reverse Fusion of Node Pairs in Scheduler

Hi,

I’m interested in PyTorch’s operator fusion (I’m still relatively new to it, so please feel free to correct me). I have been reading the source code of Inductor’s Scheduler and noticed a strange case: it checks the fusability of a pair of nodes node1, node2 in the normal order, self.can_fuse(node1, node2), but also in the reverse order, self.can_fuse(node2, node1). Specifically, I’m referring to this code:

```python
# Excerpt: `self`, `seen` (a set) and `possible_fusions` (a list)
# are defined in the enclosing Scheduler method.
def check_all_pairs(nodes):
    for node1_index, node1 in enumerate(nodes):
        # Only pairs where node1 precedes node2 in `nodes` are visited.
        for node2 in nodes[node1_index + 1 :]:
            key = (node1, node2)
            if key in seen:
                continue
            seen.add(key)

            if self.can_fuse(node1, node2):
                possible_fusions.append(key)
            elif (node2.is_template() or node2.is_foreach()) and self.can_fuse(
                node2, node1
            ):
                # foreach fusions and epilogue fusions are order dependent
                possible_fusions.append((node2, node1))
```

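To make the control flow concrete, here is a toy re-creation of the loop above. ToyNode, toy_check_all_pairs and toy_can_fuse are hypothetical stand-ins (not Inductor classes), and the fusion rule in toy_can_fuse is invented purely for illustration. It only shows that the reversed pair is tried when the forward check fails and node2 is a template or foreach node, and that the reversed tuple is what gets recorded.

```python
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen makes instances hashable, so they can go in `seen`
class ToyNode:
    # Hypothetical stand-in for a scheduler node.
    name: str
    template: bool = False
    foreach: bool = False

    def is_template(self):
        return self.template

    def is_foreach(self):
        return self.foreach


def toy_check_all_pairs(nodes, can_fuse):
    # Same shape as check_all_pairs, but `can_fuse` is passed in
    # instead of being a Scheduler method.
    possible_fusions = []
    seen = set()
    for node1_index, node1 in enumerate(nodes):
        for node2 in nodes[node1_index + 1 :]:
            key = (node1, node2)
            if key in seen:
                continue
            seen.add(key)
            if can_fuse(node1, node2):
                possible_fusions.append(key)
            elif (node2.is_template() or node2.is_foreach()) and can_fuse(node2, node1):
                # Record the reversed order: node2 goes first.
                possible_fusions.append((node2, node1))
    return possible_fusions


def toy_can_fuse(a, b):
    # Hypothetical rule: only "template first, non-template second" fuses.
    return a.is_template() and not b.is_template()


pointwise = ToyNode("pointwise")
mm = ToyNode("mm", template=True)

# mm comes second in the list, so only the reversed check succeeds.
fusions = toy_check_all_pairs([pointwise, mm], toy_can_fuse)
print([(a.name, b.name) for a, b in fusions])  # [('mm', 'pointwise')]

# Neither node is a template/foreach, so the reversed check is never tried.
plain = toy_check_all_pairs([pointwise, ToyNode("other")], toy_can_fuse)
print(plain)  # []
```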

I traced this function while running some trivial functions through torch.compile(), and it seems that nodes is typically a list containing a graph node n and all the nodes that consume its output, m_1, …, m_k (it’s probably not that simple, so feel free to correct me there). So we have the list [n, m_1, …, m_k], and in the self.can_fuse(node1, node2) case we:

  1. Check whether we can fuse n with each m_i (vertical fusion?)
  2. Check whether we can fuse m_i with m_j for each i < j (horizontal fusion?)

where i and j range over 1…k.
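The pairs visited by the nested loop can be checked with a short sketch. For a hypothetical list [n, m1, m2, m3], the loop yields exactly the same ordered pairs as itertools.combinations: every pair once, with the earlier list element always in the node1 position. forward_pairs is a hypothetical helper written only for this comparison.

```python
from itertools import combinations


def forward_pairs(nodes):
    # Same iteration order as the nested loop in check_all_pairs:
    # node1 at index i, node2 at every index j > i.
    pairs = []
    for i, node1 in enumerate(nodes):
        for node2 in nodes[i + 1 :]:
            pairs.append((node1, node2))
    return pairs


nodes = ["n", "m1", "m2", "m3"]  # hypothetical producer n and its consumers
pairs = forward_pairs(nodes)
assert pairs == list(combinations(nodes, 2))
print(pairs)
# [('n', 'm1'), ('n', 'm2'), ('n', 'm3'), ('m1', 'm2'), ('m1', 'm3'), ('m2', 'm3')]
```

So the forward check never sees a pair like (m1, n); that ordering only appears through the reversed can_fuse call.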

Now my question is: when is self.can_fuse(node2, node1) meaningful? My understanding is that fusing operations in reverse order could break the overall function’s correctness, for example if node2 consumes node1’s output but is placed before it.

Would appreciate someone’s insight here. Many thanks!