Consider the following simple module, which only does a matrix multiplication, and a TorchDynamo backend called `toy_backend`:
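```python
import torch
import torch.nn as nn

class TestModule(nn.Module):
    def forward(self, a, w):  # `self` was missing in the original snippet
        return torch.mm(a, w)

def toy_backend(gm, inputs):
    # Dynamo passes an FX GraphModule; returning gm.forward runs the captured graph as-is.
    return gm.forward

c = torch.compile(TestModule(), backend=toy_backend)
```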
Note that `gm.forward` returns a tuple even though the original `nn.Module` returns a single value (see [1]).
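The tuple is also visible in the graph structure itself: every FX graph has exactly one node with `op == "output"`, and its first argument is the returned value. A minimal sketch, assuming `gm` is the `torch.fx.GraphModule` that Dynamo hands to the backend; `output_arity` is a hypothetical helper, not part of PyTorch:

```python
import torch
import torch.fx


def output_arity(gm: torch.fx.GraphModule):
    """Return the element count if the graph returns a tuple/list, else None.

    Hypothetical helper for illustration only.
    """
    for node in gm.graph.nodes:
        if node.op == "output":
            # node.args[0] is the (possibly nested) value the graph returns.
            ret = node.args[0]
            if isinstance(ret, (tuple, list)):
                return len(ret)
            return None  # a bare value such as a single tensor node
    raise ValueError("graph has no output node")
```

Called inside `toy_backend`, `output_arity(gm)` would be expected to return `1` for a graph like the one in [1], and `None` for a graph that ends in a bare `return mm`.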
This is not a problem in this example. However, if instead of simply returning `gm.forward` we invoke the Torch-MLIR `compile` function on `gm`, compilation fails, because Torch-MLIR does not support returning a single-element tuple.
Is it possible to change the GraphModule to return a single element (not a tuple)?
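A minimal sketch of one possible approach, not verified against Torch-MLIR itself: edit the graph's output node so it returns the tuple's only element, recompile the `GraphModule`, and then hand it to the compiler. `unwrap_single_tuple_return` and `make_unwrapping_backend` are hypothetical names, and `compile_fn` is a placeholder for whatever Torch-MLIR call you make (its signature is not assumed here). The sketch also assumes Dynamo expects the backend's callable to return the same structure as the original `gm`, so it re-wraps the result in a one-element tuple at call time.

```python
import torch
import torch.fx


def unwrap_single_tuple_return(gm: torch.fx.GraphModule) -> bool:
    """Make `gm` return `x` instead of `(x,)`. Returns True if the graph changed."""
    for node in gm.graph.nodes:
        if node.op == "output":
            ret = node.args[0]
            if isinstance(ret, (tuple, list)) and len(ret) == 1:
                # Replace the one-element tuple with its single element.
                node.args = (ret[0],)
                gm.graph.lint()
                gm.recompile()  # regenerate gm.forward from the edited graph
                return True
            return False
    return False


def make_unwrapping_backend(compile_fn):
    """Wrap a (hypothetical) compiler that cannot handle one-element tuple returns.

    `compile_fn(gm, inputs)` stands in for the real compiler call and must
    return a callable.
    """
    def backend(gm: torch.fx.GraphModule, inputs):
        unwrapped = unwrap_single_tuple_return(gm)
        compiled = compile_fn(gm, inputs)
        if not unwrapped:
            return compiled

        def run(*args):
            # Restore the tuple structure the original graph returned.
            return (compiled(*args),)

        return run

    return backend
```

With a stand-in compiler, usage would look like `torch.compile(TestModule(), backend=make_unwrapping_backend(lambda gm, inputs: gm.forward))`; replacing the lambda with the Torch-MLIR call is the part this sketch leaves open.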
[1] The output of `gm.print_readable()` clearly shows a tuple being returned, and calling `gm.forward` also returns a tuple:
```python
def forward(self, L_a_ : torch.Tensor, L_w_ : torch.Tensor):
    l_a_ = L_a_
    l_w_ = L_w_
    mm = torch.mm(l_a_, l_w_); l_a_ = l_w_ = None
    return (mm,)
```