Strange behavior with PyTorch when setting tensor values by index

I have a non-zero tensor input with shape torch.Size([128, 10]) and a zero tensor of identical size called z. I'm trying to access z by index and set its values as follows:

N, C = input.shape
z = torch.zeros(input.shape)

for n in range(N):
    for c in range(C):
        if (c == target[n]):
            z[n, c] = input[n, c]
        else:
            z[n, c] = -1 * input[n, c]

where target is another non-zero tensor with shape torch.Size([128]). When I run this, only the very first element of z gets updated and the rest are all zeros. What am I doing wrong?

This code works for me and all values in z are set:

input = torch.randn(128, 10)

N, C = input.shape
z = torch.zeros(input.shape)
target = torch.randint(0, C, (N,))

for n in range(N):
    for c in range(C):
        if (c == target[n]):
            z[n, c] = input[n, c]
        else:
            z[n, c] = -1 * input[n, c]

(z==0.).float().sum()
# tensor(0.)
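As a side note, the double loop can be replaced with a vectorized version using advanced indexing, which should behave identically (a sketch, assuming the same input/target shapes as above):

```python
import torch

torch.manual_seed(0)
N, C = 128, 10
input = torch.randn(N, C)
target = torch.randint(0, C, (N,))

# Start from the negated input, then restore the entries at the target index.
z = -input
rows = torch.arange(N)
z[rows, target] = input[rows, target]
```

This avoids the Python-level loop entirely and runs in a couple of tensor ops.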

Hmm, strange, I wonder what I am doing wrong. Maybe it is related to input/target being allocated on MPS, since I am running it on a MacBook Pro with an M1.

That could be the case as I’m running on a plain Linux node. In case you are seeing the issue in the latest nightly, could you create an issue on GitHub, please?


This has been fixed in [MPS] Remove incorrect asserts from `Copy.mm` (#86184) · pytorch/pytorch@8da704c · GitHub.