How to achieve Torch-MLIR compatibility with a Dynamo backend?

Consider the following simple module, which only performs a matrix multiplication, and a Torch Dynamo backend called toy_backend.

import torch
import torch.nn as nn

class TestModule(nn.Module):
    def forward(self, a, w):
        return torch.mm(a, w)

def toy_backend(gm, inputs):
    return gm.forward

c = torch.compile(TestModule(), backend=toy_backend)
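To see the behaviour described below, here is a runnable sketch of the same setup (the input shapes are made up) that records what the captured graph returns inside the backend:

```python
import torch

class TestModule(torch.nn.Module):
    def forward(self, a, w):
        return torch.mm(a, w)

captured = {}

def toy_backend(gm, inputs):
    # Record what the captured GraphModule returns so we can inspect it.
    def wrapper(*args):
        result = gm.forward(*args)
        captured["result"] = result
        return result
    return wrapper

c = torch.compile(TestModule(), backend=toy_backend)
out = c(torch.randn(2, 3), torch.randn(3, 4))

# gm.forward returned a 1-tuple, but torch.compile unwraps it for the caller.
assert isinstance(captured["result"], tuple)
assert isinstance(out, torch.Tensor)
```

Note that the caller of the compiled module still receives a plain tensor; the tuple only appears at the backend boundary, which is exactly where Torch-MLIR would be invoked.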

Note that gm.forward will return a tuple even though the original nn.Module returns a single value (see [1]).
This is not a problem in the current example. However, if instead of simply returning gm.forward we want to invoke the Torch-MLIR compile function on gm, compilation fails, because Torch-MLIR does not support single-element tuple returns.

Is it possible to change the GraphModule to return a single element (not a tuple)?

[1] Note that the output of gm.print_readable() clearly shows a tuple being returned, and calling gm.forward also returns a tuple:

def forward(self, L_a_ : torch.Tensor, L_w_ : torch.Tensor):
    l_a_ = L_a_
    l_w_ = L_w_

    mm = torch.mm(l_a_, l_w_);  l_a_ = l_w_ = None
    return (mm,)

Not sure if this is a good way to do it, but the code below works. It erases the tuple-returning output node, emits a new output node that returns the value directly, and recompiles:

# The last two nodes are the returned value and the output node itself
# (this assumes the graph returns exactly one value).
before_output_node, output_node = list(gm.graph.nodes)[-2:]
# The output node has no users, so it can be erased safely.
gm.graph.erase_node(output_node)
# Append a new output node that returns the value directly, not as a tuple.
gm.graph.output(before_output_node)
# Regenerate gm.forward from the modified graph.
gm.recompile()
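To check the rewrite in isolation, here is a self-contained sketch that uses torch.fx.symbolic_trace on a module returning a 1-tuple (standing in for Dynamo's capture; TupleModule and the shapes are made up), applies the rewrite, and verifies the recompiled module returns a bare tensor:

```python
import torch
import torch.fx as fx

class TupleModule(torch.nn.Module):
    # Returns a 1-tuple, mimicking the graph Dynamo hands to the backend.
    def forward(self, a, w):
        return (torch.mm(a, w),)

gm = fx.symbolic_trace(TupleModule())
# Before the rewrite: the traced module returns a tuple.
assert isinstance(gm(torch.randn(2, 3), torch.randn(3, 4)), tuple)

# The rewrite: drop the tuple-returning output node and emit a new one
# that returns the single value directly.
before_output_node, output_node = list(gm.graph.nodes)[-2:]
gm.graph.erase_node(output_node)
gm.graph.output(before_output_node)
gm.recompile()

out = gm(torch.randn(2, 3), torch.randn(3, 4))
assert isinstance(out, torch.Tensor)  # no longer wrapped in a tuple
```

After this transformation the GraphModule can be handed to a compiler that rejects single-element tuple returns.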