Bug when indexing tensors using an MPS device

Hi there.

I think I found a bug, but before opening an issue in the repo I'd like to ask here.

When I run the following code on a CPU or CUDA device:

```python
import torch

device = torch.device("cpu")  # or torch.device("cuda")

x = torch.tensor([[1, 0, 2, 0, -100, -100, -100],
                  [1, 0, 2, 0, -100, -100, -100]])
mask = torch.tensor([[True, True, True, True, False, False, False],
                     [True, True, True, True, False, False, False]])

x = x.to(device)
mask = mask.to(device)

print(x[mask])
```

the output is tensor([1, 0, 2, 0, 1, 0, 2, 0]), which is correct.

However, when using MPS as the device, the result is tensor([-100, 0, 2, 0, 1, 0, 2, 0], device='mps:0'), which is incorrect: the first element should be 1, not -100.
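To check the discrepancy on a given machine, a small comparison like the one below may help. It is only a sketch: `masked_index_matches_cpu` is a hypothetical helper, not part of PyTorch. It treats the CPU result as the reference and compares both `x[mask]` and `torch.masked_select` on the target device against it. The device choice falls back to CPU when MPS is unavailable, in which case the check trivially passes.

```python
import torch

def masked_index_matches_cpu(x: torch.Tensor, mask: torch.Tensor,
                             device: torch.device) -> bool:
    """Hypothetical helper: compare boolean-mask selection on `device` with the CPU result."""
    expected = x.cpu()[mask.cpu()]  # CPU result used as the reference
    x_dev, mask_dev = x.to(device), mask.to(device)
    by_index = x_dev[mask_dev].cpu()                          # x[mask] on the device
    by_select = torch.masked_select(x_dev, mask_dev).cpu()    # masked_select on the device
    return torch.equal(expected, by_index) and torch.equal(expected, by_select)

x = torch.tensor([[1, 0, 2, 0, -100, -100, -100],
                  [1, 0, 2, 0, -100, -100, -100]])
mask = x != -100  # same mask as in the post: first four columns True

# Use MPS if present, otherwise fall back to CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(device, masked_index_matches_cpu(x, mask, device))
```

On the setup described in this post, one would expect this to print `False` for `mps`; on CPU it should print `True`.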

This does not happen when x is 1D or when the mask stays on the CPU. However, it also happens with torch.masked_select.

If this is expected behavior, could someone explain why? And is there a workaround?
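Based only on the observations above (1D inputs and a CPU-side selection behave correctly), two possible workarounds are sketched below. Both helpers are hypothetical names, and neither is verified to avoid the issue on PyTorch 2.2.2 with MPS. `masked_select_flat` flattens `x` and `mask` to 1D before indexing; because flattening is row-major, this yields the same order as `x[mask]` when `mask` has the same shape as `x`. `masked_select_via_cpu` performs the selection on the CPU and moves the result back, at the cost of a device round trip.

```python
import torch

def masked_select_flat(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical workaround: index with a 1-D mask, since the post reports
    # the 1-D case is unaffected. Row-major flattening preserves the
    # element order that x[mask] produces for a same-shape mask.
    assert x.shape == mask.shape, "mask must have the same shape as x"
    return x.reshape(-1)[mask.reshape(-1)]

def masked_select_via_cpu(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical fallback: select on the CPU, then move back to x's device.
    return x.cpu()[mask.cpu()].to(x.device)

x = torch.tensor([[1, 0, 2, 0, -100, -100, -100],
                  [1, 0, 2, 0, -100, -100, -100]])
mask = x != -100  # same mask as in the post

print(masked_select_flat(x, mask))
print(masked_select_via_cpu(x, mask))
```

The flattened variant keeps the work on the device, so it is probably preferable if it turns out to give correct results on MPS; the CPU round trip is the safer but slower fallback.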

I’m using PyTorch version 2.2.2 and a MacBook Pro with an M2 Max and macOS Sonoma 14.4.1.

Thanks.