SubgraphMatcher can't be initialized with MOE pattern

Hello.

I’m trying to replace the exported Mixtral SparseMLP code with a custom op using the fx.replace_pattern API. However, I’m seeing errors like: “SubgraphMatcher cannot be initialized with an pattern with dead code”. Because the MoE code contains a for-loop, the torch.export.export API adds runtime checks to the graph, and those checks are what the matcher reports as dead code. Can SubgraphMatcher be made to support matching graphs that contain runtime checks?

The reproducer code is attached below. A quick-and-dirty workaround (WAR) is to run the remove_dead_code pass before matching.

from transformers import AutoConfig
import torch
from torch import fx, nn
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from typing import Tuple, List
import copy
import fire

# Hugging Face model id; main() loads its config to build the MoE block under test.
MODEL_ID = "mistralai/Mixtral-8x7B-v0.1"

@torch.library.custom_op("custom::sparse_mlp", mutates_args=())
def sparse_mlp(
    tensor: torch.Tensor,
    top_k: int,
    w_gate: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor
) -> torch.Tensor:
    batch_size, sequence_length, hidden_dim = tensor.size()
    tensor = tensor.view(-1, hidden_dim)
    router_logits = F.linear(tensor, w_gate)
    routing_weights = F.softmax(router_logits, dim=1)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

    num_experts = router_logits.size(-1)
    final_hidden_states = torch.zeros(batch_size * sequence_length, hidden_dim)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

    # Loop over all available experts in the model and perform the computation on each expert
    for expert_idx in range(num_experts):
        #expert_layer = self.experts[expert_idx]
        idx, top_x = torch.where(expert_mask[expert_idx])
        expert_output = tensor[None, top_x].reshape(-1, hidden_dim)
        expert_output = F.silu(F.linear(expert_output, w1[expert_idx, :, :]) * F.linear(expert_output, w3[expert_idx, :, :]))
        expert_output = F.linear(expert_output, w2[expert_idx, :, :])

        # Index the correct hidden states and compute the expert hidden state for
        # the current expert. We need to make sure to multiply the output hidden
        # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
        current_hidden_states = expert_output * routing_weights[top_x, idx, None]

        # However `index_add_` only support torch tensors for indexing so we'll use
        # the `top_x` tensor here.
        final_hidden_states.index_add_(0, top_x, current_hidden_states)
    final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    return final_hidden_states

@sparse_mlp.register_fake
def _(
    tensor: torch.Tensor,
    top_k: int,
    w_gate: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor,
) -> torch.Tensor:
    """Fake (meta) kernel for ``sparse_mlp``.

    The result is a copy of ``tensor``, so it carries the input's shape,
    dtype and device; the remaining arguments are not consulted.
    """
    fake_output = tensor.clone()
    return fake_output

class SparseMLP(nn.Module):
    """Module wrapper that routes its forward pass through ``custom::sparse_mlp``.

    The constructor copies the router weight and the per-expert ``w1``/``w2``/``w3``
    weights out of ``model`` and stacks the expert weights along a new leading
    expert dimension.
    """

    def __init__(self, model: nn.Module):
        super().__init__()

        def stack_expert_weights(projection: str) -> nn.Parameter:
            # One weight matrix per expert, stacked on a new dim 0.
            per_expert = [getattr(expert, projection).weight for expert in model.experts]
            return nn.Parameter(torch.stack(per_expert))

        self.top_k = model.top_k
        self.w_gate = nn.Parameter(model.gate.weight)
        # Assuming the experts' projections are nn.Linear layers (weights laid
        # out as (out_features, in_features)) - TODO confirm against the model:
        #   w1, w3: (num_experts, ffn_dim, hidden_dim)
        #   w2:     (num_experts, hidden_dim, ffn_dim)
        self.w1 = stack_expert_weights("w1")
        self.w2 = stack_expert_weights("w2")
        self.w3 = stack_expert_weights("w3")

        # Validate the custom op registration once, on a random probe input.
        probe = torch.randn(1, 64, self.w_gate.size(1))
        op_args = (probe, self.top_k, self.w_gate, self.w1, self.w2, self.w3)
        torch.library.opcheck(torch.ops.custom.sparse_mlp, op_args)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the custom sparse-MLP op to ``tensor`` using the stored weights."""
        weights = (self.w_gate, self.w1, self.w2, self.w3)
        return torch.ops.custom.sparse_mlp(tensor, self.top_k, *weights)

def process(graph_module: fx.GraphModule) -> fx.GraphModule:
    """Clean up and recompile ``graph_module`` in place, then return it.

    Runs fx dead-code elimination, lints the graph, swaps in a default
    ``CodeGen`` and regenerates the module's Python code.
    """
    graph = graph_module.graph
    graph.eliminate_dead_code()
    graph.lint()

    # A fresh default codegen is installed because callers may have changed
    # the output signature of the graph.
    default_codegen = fx.graph.CodeGen()
    graph.set_codegen(default_codegen)

    graph_module.recompile()
    return graph_module

def take_first_pass(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Rewrite the graph so it returns only the first element of its output.

    Each ``output`` node is replaced by a new output node whose value is
    ``args[0][0]`` of the old one; the module is then cleaned up and
    recompiled via ``process``.
    """
    graph = graph_module.graph
    for node in reversed(graph.nodes):
        if node.op != "output":
            continue

        first_result = node.args[0][0]
        with graph.inserting_after(node):
            new_output = graph.output(first_result, torch.Tensor)
            node.replace_all_uses_with(new_output)
            graph.erase_node(node)

    return process(graph_module)

def remove_dead_code_pass(graph_module: fx.GraphModule) -> fx.GraphModule:
    """Erase every node that has no users, then clean up via ``process``.

    Placeholders and output nodes are always kept. No check is made for
    whether an erased node has side effects. Nodes are visited in reverse
    order, so a node whose only consumer was just erased is itself erased
    later in the same sweep.
    """
    protected_ops = {"output", "placeholder"}
    graph = graph_module.graph

    for node in reversed(graph.nodes):
        if node.users or node.op in protected_ops:
            continue
        print(f"erasing {node=}")
        graph.erase_node(node)

    return process(graph_module)

@torch.inference_mode()
def main(seq_len=64, remove_dead_code=False, strict=False):
    """Reproduce the SubgraphMatcher failure on an exported Mixtral MoE block.

    Args:
        seq_len: Sequence length of the random example input.
        remove_dead_code: When true, run ``remove_dead_code_pass`` on the
            exported graph before it is used as the match pattern.
        strict: Forwarded to ``torch.export.export`` as ``strict``.
    """
    moe_config = AutoConfig.from_pretrained(MODEL_ID)
    moe_block = MixtralSparseMoeBlock(moe_config)
    example_input = torch.randn(1, seq_len, moe_config.hidden_size)

    # Export the reference block and keep only its first output.
    exported = torch.export.export(moe_block, (example_input,), strict=strict)
    graph_module = take_first_pass(exported.module())

    if remove_dead_code:
        graph_module = remove_dead_code_pass(graph_module)

    # Pattern needs to be an independent GraphModule for fx.replace_pattern to work.
    pattern = copy.deepcopy(graph_module)

    # Export the custom-op module that should replace the matched subgraph.
    custom_op_module = SparseMLP(moe_block)
    replacement = torch.export.export(
        custom_op_module, (example_input,), strict=strict
    ).module()

    fx.replace_pattern(graph_module, pattern, replacement)

    graph_module.print_readable()

    runtime_output = graph_module(example_input)
    print(f"runtime result: {runtime_output}")

# When run as a script, hand main() to python-fire.
if __name__ == "__main__":
    fire.Fire(main)