Hello.
I’m trying to replace the exported Mixtral SparseMLP code with a custom op using the fx.replace_pattern API. However, I’m seeing errors like: “SubgraphMatcher cannot be initialized with an pattern with dead code”. Because the MoE code contains a for-loop, the torch.export.export API adds runtime checks to the graph, and the matcher apparently treats these checks as dead code. Can we get SubgraphMatcher to support matching graphs that contain runtime checks?
I have attached the reproducer code below. A quick-and-dirty workaround is to run the remove_dead_code pass before the match.
from transformers import AutoConfig
import torch
from torch import fx, nn
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
from typing import Tuple, List
import copy
import fire
# Model identifier whose configuration main() loads with AutoConfig.from_pretrained.
MODEL_ID = "mistralai/Mixtral-8x7B-v0.1"
@torch.library.custom_op("custom::sparse_mlp", mutates_args=())
def sparse_mlp(
tensor: torch.Tensor,
top_k: int,
w_gate: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w3: torch.Tensor
) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = tensor.size()
tensor = tensor.view(-1, hidden_dim)
router_logits = F.linear(tensor, w_gate)
routing_weights = F.softmax(router_logits, dim=1)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
num_experts = router_logits.size(-1)
final_hidden_states = torch.zeros(batch_size * sequence_length, hidden_dim)
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(num_experts):
#expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
expert_output = tensor[None, top_x].reshape(-1, hidden_dim)
expert_output = F.silu(F.linear(expert_output, w1[expert_idx, :, :]) * F.linear(expert_output, w3[expert_idx, :, :]))
expert_output = F.linear(expert_output, w2[expert_idx, :, :])
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_hidden_states = expert_output * routing_weights[top_x, idx, None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states)
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states
@sparse_mlp.register_fake
def _(
    tensor: torch.Tensor,
    top_k: int,
    w_gate: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w3: torch.Tensor,
) -> torch.Tensor:
    """Fake (meta) kernel for ``custom::sparse_mlp``.

    Only output metadata matters here. The real kernel returns a freshly
    allocated, contiguous tensor shaped like ``tensor``, so allocate an
    uninitialised contiguous tensor of that shape instead of cloning the
    input (a clone would carry over the strides of a non-contiguous input).
    """
    return tensor.new_empty(tensor.shape)
class SparseMLP(nn.Module):
    """Module wrapper that runs a sparse MoE block through ``custom::sparse_mlp``.

    The constructor copies ``top_k`` and the gate weight from ``model`` and
    stacks the ``w1``/``w2``/``w3`` weights of every entry in ``model.experts``
    along a new leading (expert) dimension.
    """

    def __init__(self, model: nn.Module):
        super().__init__()

        def stacked(projection: str) -> nn.Parameter:
            # Gather `<expert>.<projection>.weight` from each expert and stack
            # them on dim 0, giving shape (num_experts, *weight.shape).
            per_expert = [getattr(expert, projection).weight for expert in model.experts]
            return nn.Parameter(torch.stack(per_expert))

        self.top_k = model.top_k
        self.w_gate = nn.Parameter(model.gate.weight)
        self.w1 = stacked("w1")
        self.w2 = stacked("w2")
        self.w3 = stacked("w3")

        # Validate the custom op registration once, on a random probe input
        # whose last dimension matches the gate weight's input dimension.
        probe = torch.randn(1, 64, self.w_gate.size(1))
        op_args = (probe, self.top_k, self.w_gate, self.w1, self.w2, self.w3)
        torch.library.opcheck(torch.ops.custom.sparse_mlp, op_args)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Apply the custom sparse MLP op to ``tensor`` with the stored weights."""
        weights = (self.w_gate, self.w1, self.w2, self.w3)
        return torch.ops.custom.sparse_mlp(tensor, self.top_k, *weights)
def process(graph_module: fx.GraphModule) -> fx.GraphModule:
    """Clean up and recompile ``graph_module`` in place after a graph edit.

    Runs dead-code elimination, lints the graph, installs a default code
    generator and recompiles. Returns the same ``graph_module`` object.
    """
    graph = graph_module.graph
    graph.eliminate_dead_code()
    graph.lint()
    # A default CodeGen is installed because callers may have changed the
    # output signature of the graph.
    default_codegen = fx.graph.CodeGen()
    graph.set_codegen(default_codegen)
    graph_module.recompile()
    return graph_module
def take_first_pass(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Rewrite the graph so it returns only the first element of its output.

    Each ``output`` node is replaced by a new output node that returns
    ``args[0][0]`` of the old one, annotated as ``torch.Tensor``; the module
    is then cleaned up and recompiled via ``process``.
    """
    graph = graph_module.graph
    output_nodes = [candidate for candidate in reversed(graph.nodes) if candidate.op == "output"]
    for old_output in output_nodes:
        first_value = old_output.args[0][0]
        with graph.inserting_after(old_output):
            new_output = graph.output(first_value, torch.Tensor)
        old_output.replace_all_uses_with(new_output)
        graph.erase_node(old_output)
    return process(graph_module)
def remove_dead_code_pass(graph_module: fx.GraphModule) -> fx.GraphModule:
    """Erase every user-less node other than placeholders and outputs.

    Nodes are visited in reverse order so that erasing a node can expose
    its producers as user-less before they are visited. Each erased node is
    printed, and the module is then cleaned up and recompiled via ``process``.
    """
    graph = graph_module.graph
    protected_ops = {"output", "placeholder"}
    for node in reversed(graph.nodes):
        if node.users or node.op in protected_ops:
            continue
        print(f"erasing {node=}")
        graph.erase_node(node)
    return process(graph_module)
@torch.inference_mode()
def main(seq_len=64, remove_dead_code=False, strict=False):
    """Reproduce replacing an exported Mixtral MoE block with the custom op.

    Args:
        seq_len: Sequence length of the random example input.
        remove_dead_code: When true, run ``remove_dead_code_pass`` on the
            exported graph before it is used as the match pattern.
        strict: Forwarded to ``torch.export.export`` as ``strict``.
    """
    moe_config = AutoConfig.from_pretrained(MODEL_ID)
    moe_block = MixtralSparseMoeBlock(moe_config)
    sample = torch.randn(1, seq_len, moe_config.hidden_size)
    example_args = (sample,)

    # Export the reference block and reduce its output to the first element.
    exported_block = torch.export.export(moe_block, example_args, strict=strict)
    graph_module = take_first_pass(exported_block.module())
    if remove_dead_code:
        graph_module = remove_dead_code_pass(graph_module)

    # Pattern needs to be an independent GraphModule for fx.replace_pattern to work.
    pattern = copy.deepcopy(graph_module)
    exported_op = torch.export.export(SparseMLP(moe_block), example_args, strict=strict)
    replacement = exported_op.module()
    fx.replace_pattern(graph_module, pattern, replacement)

    graph_module.print_readable()
    result = graph_module(sample)
    print(f"runtime result: {result}")
# Script entry point: hand main() to python-fire so it can be driven from the
# command line.
if __name__ == "__main__":
    fire.Fire(main)