Hi,
I have a question about using vectorized indexing to assign entries from a source tensor to a target tensor when the target indexes contain collisions (duplicates).
Specifically, consider the following example:
```python
import torch

device = torch.device("cpu")
source_tensor = torch.tensor(
    [1.23, 2.68, 9.37, 5.84, 6.87, 5.21, 7.46, 2.33, 8.95],
    dtype=torch.float32,
    device=device,
)
# Several source entries map to the same target slot (0 or 1).
indexes = torch.tensor(
    [0, 0, 1, 0, 0, 1, 0, 1, 1],
    dtype=torch.long,
    device=device,
)
target_tensor = torch.zeros(2, dtype=torch.float32, device=device)
target_tensor[indexes] = source_tensor  # which entry wins each slot?
print(target_tensor)
```
Here, multiple entries from the source tensor get assigned to the same locations (0 and 1) in the target tensor. In this example the number of points is still small, but in my practical case it can be 10K+.
- The first issue I encountered is that I get different results if I change the device to `torch.device("cuda")`. This means collisions are handled differently on the GPU and the CPU. How do I make the behavior consistent?
- Related to the issue above: in this example, the CPU seems to assign the entries sequentially (at least on my desktop), so one gets `tensor([7.4600, 8.9500])` as the result (the last entries assigned to locations 0 and 1, respectively). But in my practical case (> 10K points), the assignment is not sequential even on the CPU. Is there a way to enforce sequential behavior on both GPU and CPU? If so, we could pre-sort the tensor whenever we want to always fetch the minimum on collisions. One might suggest simply removing the collisions, but I don't know how to do that with vectorized operations, and doing it without them makes the process very slow.
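As far as I know, PyTorch documents plain assignment through duplicate indices (`index_put_` with `accumulate=False`) as undefined, so neither device's ordering should be relied on. A minimal sketch, assuming a recent PyTorch (1.13+, for `Tensor.scatter_reduce_` and `torch.sort(..., stable=True)`), of two vectorized ways around that. The helper names `assign_last_wins` and `dedupe_keep_min` are hypothetical. The first makes "last write wins" explicit by taking the maximum source position per slot, which does not depend on processing order. The second removes collisions by sorting by value, then stably by target index, and keeping the first entry of each group, which is the per-slot minimum:

```python
import torch


def assign_last_wins(source, indexes, size):
    """Hypothetical helper: emulate `target[indexes] = source` so that the entry
    with the largest source position always wins, on any device."""
    positions = torch.arange(source.numel(), device=source.device)
    # -1 marks target slots that receive no source entry.
    last_pos = torch.full((size,), -1, dtype=torch.long, device=source.device)
    last_pos.scatter_reduce_(0, indexes, positions, reduce="amax", include_self=True)
    target = torch.zeros(size, dtype=source.dtype, device=source.device)
    hit = last_pos >= 0
    target[hit] = source[last_pos[hit]]  # no duplicate indexes remain here
    return target, last_pos


def dedupe_keep_min(source, indexes):
    """Hypothetical helper: for every target index, return the position of the
    smallest source value mapped to it (ties go to the earliest position)."""
    # Sort by value; stable so equal values keep their original order.
    _, by_value = torch.sort(source, stable=True)
    # Stable sort by target index: each index group stays sorted by value.
    _, by_index = torch.sort(indexes[by_value], stable=True)
    perm = by_value[by_index]
    keys = indexes[perm]
    # The first element of each index group is that group's minimum.
    first = torch.ones_like(keys, dtype=torch.bool)
    first[1:] = keys[1:] != keys[:-1]
    return perm[first]


source_tensor = torch.tensor([1.23, 2.68, 9.37, 5.84, 6.87, 5.21, 7.46, 2.33, 8.95])
indexes = torch.tensor([0, 0, 1, 0, 0, 1, 0, 1, 1])

target, src_idx = assign_last_wins(source_tensor, indexes, 2)
print(target, src_idx)  # tensor([7.4600, 8.9500]) tensor([6, 8])

keep = dedupe_keep_min(source_tensor, indexes)
target_min = torch.zeros(2)
target_min[indexes[keep]] = source_tensor[keep]  # collision-free assignment
print(target_min, keep)  # tensor([1.2300, 2.3300]) tensor([0, 7])
```

Both helpers also return the selected source positions (`src_idx`, `keep`), which is the inverse index. Because they use only sorts, comparisons and an order-independent max reduction, they should give the same result on CPU and CUDA; the positions in `keep` come out ordered by target index.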
Another potential solution is `Tensor.index_reduce_()`. With this function we can specify a reduction rule such as `amin`. However, it doesn't give me the inverse indexes, i.e. I can't tell which entry of `source_tensor` got selected. Is there any suggestion on how to get the inverse indexes?
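One way to recover the inverse index after `index_reduce_`, sketched under the assumption of PyTorch 1.12+ (where `index_reduce_` and `scatter_reduce_` exist): since `amin` returns one of the inputs exactly, compare each source entry with the minimum of its own slot, then take the smallest matching position per slot with a second `scatter_reduce_`. The helper name `index_reduce_amin_with_argmin` is hypothetical. Note `include_self=False`, otherwise the zeros in the freshly created target would join the minimum:

```python
import torch


def index_reduce_amin_with_argmin(source, indexes, size):
    """Hypothetical helper: per-slot minimum via index_reduce_, plus the
    source position that produced it (-1 for slots that got nothing)."""
    target = torch.zeros(size, dtype=source.dtype, device=source.device)
    # include_self=False: the initial zeros must not take part in the min.
    target.index_reduce_(0, indexes, source, "amin", include_self=False)
    # amin returns one of the inputs exactly, so equality finds the winners.
    is_min = source == target[indexes]
    positions = torch.arange(source.numel(), device=source.device)
    sentinel = source.numel()  # larger than any valid position
    candidates = torch.where(is_min, positions, torch.full_like(positions, sentinel))
    # Among tied winners, keep the earliest position.
    argmin = torch.full((size,), sentinel, dtype=torch.long, device=source.device)
    argmin.scatter_reduce_(0, indexes, candidates, reduce="amin", include_self=True)
    argmin[argmin == sentinel] = -1
    return target, argmin


source_tensor = torch.tensor([1.23, 2.68, 9.37, 5.84, 6.87, 5.21, 7.46, 2.33, 8.95])
indexes = torch.tensor([0, 0, 1, 0, 0, 1, 0, 1, 1])
target, argmin = index_reduce_amin_with_argmin(source_tensor, indexes, 2)
print(target, argmin)  # tensor([1.2300, 2.3300]) tensor([0, 7])
```

This assumes the source contains no NaNs, since `NaN == NaN` is false and such a slot would report -1.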
Thanks