Jitting fails on MPS device

import torch

def f(x): return x.to(torch.uint8)
x = torch.tensor([-1.0, -2.0], device=torch.device("mps"))

print(f(x))                   # eager: works
print(torch.compile(f)(x))    # default (inductor) backend: raises the error below

torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: 'NoneType' object is not callable

The above fails when I compile on the Metal (MPS) device. On the CPU, however, it passes and inductor generates OpenMP code.
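For reference, roughly what the CPU comparison looks like (a minimal sketch: same function, just a CPU tensor, default backend):

import torch

def f(x): return x.to(torch.uint8)

# Same function on a CPU tensor: the default inductor backend compiles fine
# and emits C++/OpenMP code.
x_cpu = torch.tensor([-1.0, -2.0])
print(f(x_cpu))
print(torch.compile(f)(x_cpu))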

Device info: Apple M1 Max, torch 2.1.0

The default backend, inductor, is not supported on MPS, so you'll need to use torch.compile(m, backend="aot_eager").
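A minimal sketch of that workaround for your repro (aot_eager traces the graph ahead of time but still runs eager kernels, so it avoids inductor's code generation):

import torch

def f(x): return x.to(torch.uint8)
x = torch.tensor([-1.0, -2.0], device="mps")

# aot_eager traces and partitions the graph but executes eager kernels,
# so it sidesteps the inductor codegen path that fails on MPS.
compiled_f = torch.compile(f, backend="aot_eager")
print(compiled_f(x))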

Thanks. Although that defeats the purpose of torch.compile for me if it's still eager; my goal was to see the generated device-specific kernels.
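In the meantime I can at least look at the generated code on CPU. A rough sketch, assuming the output_code logging artifact works as documented in 2.1 (same effect as running with TORCH_LOGS="output_code"):

import torch
import torch._logging

def f(x): return x.to(torch.uint8)

# Print the inductor-generated code (C++/OpenMP on CPU) for each compiled graph.
torch._logging.set_logs(output_code=True)

x_cpu = torch.tensor([-1.0, -2.0])
print(torch.compile(f)(x_cpu))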

Is there a GitHub tracker for MPS support in compile mode? Maybe I could contribute.

@kulinseth was working on this, I believe. My understanding is that MPS would need a Triton backend, and then things should just work with Inductor, although the details probably matter.