I’m having trouble understanding why this piece of code does not work as expected:
import torch

a = torch.rand(1000)
idx1 = torch.arange(500, 600)   # integer index tensor
idx2 = torch.rand(100) < 0.5    # boolean mask over the 100 selected elements
a[idx1][idx2] = 1.0
print(a[idx1][idx2])
I expect the result to be all ones, but the tensor is unchanged. If I instead index only once, like so:
a[idx1] = 1.0
print(a[idx1])
Then, as expected, I get all ones. Why does the double indexing not work, and how can I achieve this behavior?
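For reference, here is a minimal sketch of what seems to be going on and one possible workaround. It relies on my understanding that indexing with an index tensor (`a[idx1]`) returns a copy rather than a view, so the second assignment writes into a temporary. The names `before`, `unchanged`, `combined`, and `all_ones` are made up for illustration, and the fixed seed is only there to make the run reproducible.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed, only for reproducibility

a = torch.rand(1000)
idx1 = torch.arange(500, 600)
idx2 = torch.rand(100) < 0.5

# a[idx1] is indexed with a tensor, which yields a copy of the selected
# elements. The assignment below modifies that temporary copy, not `a`.
before = a.clone()
a[idx1][idx2] = 1.0
unchanged = torch.equal(a, before)

# Possible workaround: compose the two indices first, then perform a
# single indexed assignment directly on `a`.
combined = idx1[idx2]          # positions in `a` that should be set
a[combined] = 1.0
all_ones = bool((a[idx1][idx2] == 1.0).all())

print(unchanged, all_ones)
```

Since `idx1` happens to be a contiguous range here, `a[500:600][idx2] = 1.0` should also work, because a plain slice returns a view into `a`. The composed index is the more general option when `idx1` is arbitrary.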