I have the following piece of code.
```python
import torch

device = torch.device("cuda:0")
amt = 316377
depth, width, height = 1379, 280, 85
colors = torch.randn(amt, 3, dtype=torch.float32, device=device)
mask = torch.rand((amt,), device=device) < 0.95
above_amt = int(mask.sum())  # number of selected colors, as a plain Python int
# One random (x, y, z) cell per selected color
x_idx = torch.randint(0, depth, (above_amt,), dtype=torch.long, device=device)
y_idx = torch.randint(0, width, (above_amt,), dtype=torch.long, device=device)
z_idx = torch.randint(0, height, (above_amt,), dtype=torch.long, device=device)
grid = torch.full((depth, width, height, 3), -1, device=device, dtype=torch.float32)
```
Then I do the following assignment:
```python
grid[x_idx, y_idx, z_idx] = colors[mask]
```
I was expecting `grid[x_idx, y_idx, z_idx]` and `colors[mask]` to have the same values. It seems this is not the case:
```python
>>> (colors[mask] != grid[x_idx, y_idx, z_idx]).any()
tensor(True, device='cuda:0')
```
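To spell out the expectation: an advanced-indexing assignment like `grid[x_idx, y_idx, z_idx] = values` should write row `i` of `values` into cell `(x_idx[i], y_idx[i], z_idx[i])`. The minimal sketch below checks this on a tiny hypothetical CPU grid (sizes chosen only so an explicit loop is cheap), drawing the cells with `torch.randperm` so that no cell is repeated. It compares the vectorized write against an explicit loop and against `Tensor.index_put_`, the in-place method this kind of bracket assignment roughly corresponds to.

```python
import torch

torch.manual_seed(0)
# Hypothetical tiny grid, not the sizes from the question
depth, width, height = 4, 5, 6
n = 10

# Draw n *distinct* flat cell indices, then unravel them into (x, y, z)
flat = torch.randperm(depth * width * height)[:n]
x_idx = flat // (width * height)
y_idx = (flat // height) % width
z_idx = flat % height
values = torch.randn(n, 3)

# Vectorized advanced-indexing assignment, as in the question
grid_a = torch.full((depth, width, height, 3), -1.0)
grid_a[x_idx, y_idx, z_idx] = values

# Same write spelled out: row i of values goes to cell (x_idx[i], y_idx[i], z_idx[i])
grid_b = torch.full((depth, width, height, 3), -1.0)
for i in range(n):
    x, y, z = int(x_idx[i]), int(y_idx[i]), int(z_idx[i])
    grid_b[x, y, z] = values[i]

# Same write via the in-place index_put_ method
grid_c = torch.full((depth, width, height, 3), -1.0)
grid_c.index_put_((x_idx, y_idx, z_idx), values)

print(torch.equal(grid_a, grid_b), torch.equal(grid_a, grid_c))
print(torch.equal(grid_a[x_idx, y_idx, z_idx], values))
```

With distinct cells, all three grids agree and reading the cells back returns `values` unchanged.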
Am I missing something? Shouldn’t the indexing work the way I expect it to?
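One thing worth checking is whether the independently drawn `x_idx`, `y_idx`, `z_idx` produce repeated `(x, y, z)` triples. With roughly 300k draws into about 32.8M cells, birthday-problem arithmetic suggests on the order of a thousand colliding pairs. The PyTorch documentation for `index_put_` states that with `accumulate=False` the behavior is undefined if the indices contain duplicates, so a repeated cell would keep only one of the colors written to it. The sketch below is a diagnostic under that assumption: it uses smaller hypothetical sizes so it runs quickly on CPU or GPU, flattens each triple into a linear cell index, and uses `torch.unique(..., return_inverse=True, return_counts=True)` to separate rows that own their cell from rows that share it.

```python
import torch

torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Hypothetical smaller sizes; collisions are still near-certain at this density
amt = 5000
depth, width, height = 50, 40, 30

colors = torch.randn(amt, 3, dtype=torch.float32, device=device)
mask = torch.rand((amt,), device=device) < 0.95
above_amt = int(mask.sum())
x_idx = torch.randint(0, depth, (above_amt,), device=device)
y_idx = torch.randint(0, width, (above_amt,), device=device)
z_idx = torch.randint(0, height, (above_amt,), device=device)

grid = torch.full((depth, width, height, 3), -1.0, device=device)
values = colors[mask]
grid[x_idx, y_idx, z_idx] = values

# Collapse each (x, y, z) triple into a single linear cell index
lin = (x_idx * width + y_idx) * height + z_idx
_, inverse, counts = torch.unique(lin, return_inverse=True, return_counts=True)
is_unique = counts[inverse] == 1  # True where the row's cell is written exactly once

readback = grid[x_idx, y_idx, z_idx]
mismatch = (values != readback).any(dim=1)

print("rows:", lin.numel())
print("rows sharing a cell with another row:", int((~is_unique).sum()))
print("mismatching rows:", int(mismatch.sum()))
print("mismatches among uniquely indexed rows:", int(mismatch[is_unique].sum()))
```

If the mismatches are confined to rows that share a cell, the comparison is failing because of repeated coordinates rather than because the indexing itself is wrong.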