Issue: topk>16 raises Error on MPS

a2mps = torch.topk(input=a.to("mps"), k=20)  # not working at all
RuntimeError: Currently topk on mps works only for k<=16

a1mps = torch.topk(input=a.to("mps"), k=15)  # working fine

The corresponding issue is pytorch/pytorch#78915: "MPS: Add support for TopK (k>16) on M1 GPU".
When is the fix expected?
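In the meantime, a possible workaround (a sketch, not an official fix) is to fall back to the CPU for the topk call when k > 16 on an MPS tensor and move the results back afterwards. The helper name `topk_mps_fallback` below is hypothetical:

```python
import torch

def topk_mps_fallback(x, k):
    # Hypothetical workaround: the MPS topk kernel currently only supports
    # k <= 16, so route larger k through the CPU and move results back.
    if x.device.type == "mps" and k > 16:
        values, indices = torch.topk(x.cpu(), k=k)
        return values.to(x.device), indices.to(x.device)
    return torch.topk(x, k=k)

a = torch.randn(100)
values, indices = topk_mps_fallback(a, k=20)
```

This trades a device-to-host copy for correctness, so it is only worthwhile until the MPS kernel gains native support.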

print(torch.__version__): 1.14.0.dev20221022