Indexed assignment behaves unexpectedly

I have the following piece of code.

import torch

device = torch.device("cuda:0")

amt = 316377
depth, width, height = 1379, 280, 85

colors = torch.randn(amt, 3, dtype=torch.float32, device=device)
mask = torch.rand((amt,), device=device) < 0.95

above_amt = mask.sum()

x_idx = torch.randint(0, depth, (above_amt,), dtype=torch.long, device=device)
y_idx = torch.randint(0, width, (above_amt,), dtype=torch.long, device=device)
z_idx = torch.randint(0, height, (above_amt,), dtype=torch.long, device=device)

grid = torch.full((depth, width, height, 3), -1, device=device, dtype=torch.float32)

By doing the following assignment:

grid[x_idx, y_idx, z_idx] = colors[mask]

I was expecting grid[x_idx, y_idx, z_idx] and colors[mask] to have the same values. It seems this is not the case:

(colors[mask] != grid[x_idx, y_idx, z_idx]).any()
-------------------------------------------------
[Out] tensor(True, device='cuda:0')

Am I missing something? Shouldn't the indexing work as I am expecting it to?

The issue is caused by duplicate indices in x_idx, y_idx, z_idx, as seen here:

import torch

device = "cpu"

amt = 316377
depth, width, height = 1379, 280, 85

colors = torch.randn(amt, 3, dtype=torch.float32, device=device)
mask = torch.rand((amt,), device=device) < 0.95

above_amt = mask.sum()

x_idx = torch.randint(0, depth, (above_amt,), dtype=torch.long, device=device)
y_idx = torch.randint(0, width, (above_amt,), dtype=torch.long, device=device)
z_idx = torch.randint(0, height, (above_amt,), dtype=torch.long, device=device)

grid = torch.full((depth, width, height, 3), -1, device=device, dtype=torch.float32)
grid[x_idx, y_idx, z_idx] = colors[mask]
res = (colors[mask] != grid[x_idx, y_idx, z_idx])
idx = res.nonzero()
example_idx = idx[:, 0][0]
print(example_idx)
# tensor(161)

# verify mismatch
print(colors[mask][example_idx])
# tensor([1.7955, 1.4412, 0.5304])
print(grid[x_idx, y_idx, z_idx][example_idx])
# tensor([-0.0593,  0.2556, -0.4934])

# search for duplicates
index = torch.stack([x_idx, y_idx, z_idx], dim=1)
target = torch.stack((x_idx[example_idx], y_idx[example_idx], z_idx[example_idx]))
print((index == target.unsqueeze(0)).all(dim=1).nonzero())
# tensor([[  161],
#         [20486]])

# check second index
print(colors[mask][20486])
# tensor([-0.0593,  0.2556, -0.4934])
print(grid[x_idx, y_idx, z_idx][example_idx])
# tensor([-0.0593,  0.2556, -0.4934])
print(grid[x_idx, y_idx, z_idx][20486])
# tensor([-0.0593,  0.2556, -0.4934])

This example shows that using x_idx, y_idx, z_idx creates the same index twice, at position 161 and 20486, both using:

print(torch.stack((x_idx[161], y_idx[161], z_idx[161])))
# tensor([1229,  133,   45])
print(torch.stack((x_idx[20486], y_idx[20486], z_idx[20486])))
# tensor([1229,  133,   45])

where the assignment at the later position (20486) overwrites the one at the earlier position (161).
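If each cell should keep exactly one of the values written to it, one workaround is to drop the duplicate indices before the assignment. Here is a sketch (using small made-up sizes, not your original dimensions) that flattens each (x, y, z) triple into a linear index and keeps only the first occurrence of each:

```python
import torch

# Small made-up sizes for illustration only
depth, width, height = 4, 3, 2
n = 50

x = torch.randint(0, depth, (n,))
y = torch.randint(0, width, (n,))
z = torch.randint(0, height, (n,))
vals = torch.randn(n, 3)

# Flatten each (x, y, z) triple into one linear index per grid cell
lin = (x * width + y) * height + z

# A stable sort groups equal indices together while preserving
# the original order within each group
order = torch.argsort(lin, stable=True)
sorted_lin = lin[order]

# Mark the first element of each group of equal linear indices
is_first = torch.ones(n, dtype=torch.bool)
is_first[1:] = sorted_lin[1:] != sorted_lin[:-1]

# Map the keep-mask back to the original (unsorted) positions
keep = torch.zeros(n, dtype=torch.bool)
keep[order] = is_first

grid = torch.full((depth, width, height, 3), -1.0)
grid[x[keep], y[keep], z[keep]] = vals[keep]

# With duplicates removed, the round trip now holds
print((grid[x[keep], y[keep], z[keep]] == vals[keep]).all())
# tensor(True)
```

Because the sort is stable, the kept position for each duplicated cell is the earliest one in the original index tensors.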

You should also note that which duplicate "wins" is not guaranteed: indexed assignment with duplicate indices is non-deterministic, in particular on CUDA.
