Hi,
I would like to do something like this:
```python
import torch

a = torch.randint(0, 5, (6,))
inds = torch.randint(0, 7, (5,))
for i, ind in enumerate(inds):
    a[a == ind] = i
```
The for loop is likely to be slow. Is there a better way to do this, please?
In your loop you are assigning the “last” `i` value to `a` in case `inds` contains duplicated values, and you are also overwriting already replaced values.
E.g. assume that `inds` contains `[1, 0]`: if `a` contains any `1`s, they would be replaced with a `0`, and in the next iteration all `0`s would be replaced with `1`s again.
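A small sketch of this overwrite effect, using a hypothetical `a` chosen so that it contains both a `0` and some `1`s, with `inds = [1, 0]` as in the example:

```python
import torch

# Hypothetical input: contains a 0 and two 1s
a = torch.tensor([0, 1, 1, 2])
inds = torch.tensor([1, 0])

out = a.clone()
for i, ind in enumerate(inds):
    # Iteration 0: all 1s become 0 -> [0, 0, 0, 2]
    # Iteration 1: all 0s (original and freshly written) become 1 -> [1, 1, 1, 2]
    out[out == ind] = i

print(out)  # tensor([1, 1, 1, 2])
```

The original `0` and the original `1`s end up with the same value, because the second iteration cannot distinguish values that were written in the first iteration from values that were there from the start.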
Given that, I think a “vectorized” approach might be tricky, as it would not reproduce this sequential behavior.
Nevertheless, here is a potentially faster approach, which will not yield the same results if the values in `inds` are not unique:
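```python
import torch

input = torch.randint(0, 5, (6,))
a = input.clone()  # result of the original loop
b = input.clone()  # result of the broadcast approach
inds = torch.randint(0, 7, (5,))

# Original loop: sequential, in-place replacement
for i, ind in enumerate(inds):
    a[a == ind] = i

# Broadcast approach: compare every element of b against every value in inds
idx = (b.unsqueeze(0) == inds.unsqueeze(1)).nonzero()
b[idx[:, 1]] = idx[:, 0]
```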
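To make the broadcast step more concrete, here is a sketch that prints the intermediate shapes. The input values are taken from the example later in this thread. Comparing a `(1, N)` tensor with a `(K, 1)` tensor broadcasts to a `(K, N)` boolean mask, so memory grows with `len(inds) * len(b)`:

```python
import torch

inds = torch.tensor([3, 5, 7, 8])               # shape (K,) = (4,)
b = torch.tensor([5, 5, 5, 8, 8, 3, 7, 3, 7])   # shape (N,) = (9,)

# mask[k, n] is True when b[n] == inds[k]
mask = b.unsqueeze(0) == inds.unsqueeze(1)
print(mask.shape)  # torch.Size([4, 9])

# nonzero() returns one (k, n) row per True entry:
# column 0 is the position in inds (the replacement value),
# column 1 is the position in b that should receive it
idx = mask.nonzero()

out = b.clone()
out[idx[:, 1]] = idx[:, 0]
print(out)  # tensor([1, 1, 1, 3, 3, 0, 2, 0, 2])
```

Elements of `b` that match no value in `inds` produce no row in `idx` and stay unchanged. If `inds` contained duplicates, the same position in `b` would be written more than once in a single indexed assignment, and PyTorch does not guarantee which write wins. This is why the result can differ from the loop in that case.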
Hi,
It's my fault, I did not describe the problem correctly. What I would like to do is this:
For example, the inputs are:
```python
ind = torch.tensor([3, 5, 7, 8])
a = torch.tensor([5, 5, 5, 8, 8, 3, 7, 3, 7])
```
Here the values in `ind` are unique, but the values in `a` are not.
The expected output is:
```python
out = torch.tensor([1, 1, 1, 3, 3, 0, 2, 0, 2])
```
I did not mean to do this sequentially; my implementation above is wrong. Is there a good way to do this, please?
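Put differently, each element of `a` should be replaced by its position in `ind`. A slow but unambiguous pure-Python reference for that mapping (a sketch, assuming every value of `a` occurs in `ind`, otherwise the dictionary lookup raises `KeyError`) could look like:

```python
import torch

ind = torch.tensor([3, 5, 7, 8])
a = torch.tensor([5, 5, 5, 8, 8, 3, 7, 3, 7])

# Map each value in ind to its position; valid because ind is unique
pos_of = {v: i for i, v in enumerate(ind.tolist())}

# Look up every element of a (raises KeyError for values missing from ind)
out = torch.tensor([pos_of[v] for v in a.tolist()])
print(out)  # tensor([1, 1, 1, 3, 3, 0, 2, 0, 2])
```

Such a reference is mainly useful for checking faster vectorized versions against it.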
Yes, in that case my approach works, and you can compare the results with:
```python
import torch

inds = torch.tensor([3, 5, 7, 8])
input = torch.tensor([5, 5, 5, 8, 8, 3, 7, 3, 7])
a = input.clone()
b = input.clone()

for i, ind in enumerate(inds):
    a[a == ind] = i

idx = (b.unsqueeze(0) == inds.unsqueeze(1)).nonzero()
b[idx[:, 1]] = idx[:, 0]

print((a == b).all())
# > tensor(True)
```
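If `inds` and `input` get large, the `(K, N)` mask may use a lot of memory. An alternative not discussed above is to sort `inds` once and use `torch.searchsorted`. It needs only `O(K + N)` extra memory and also works for unsorted `inds`. This is a sketch, assuming `inds` is unique. Values that do not occur in `inds` are left unchanged, mirroring the loop:

```python
import torch

inds = torch.tensor([3, 5, 7, 8])
input = torch.tensor([5, 5, 5, 8, 8, 3, 7, 3, 7])

# Sort once; order maps positions in sorted_inds back to positions in inds
sorted_inds, order = torch.sort(inds)

# For each element, find its (leftmost) insertion point in sorted_inds
pos = torch.searchsorted(sorted_inds, input)
# Values larger than every entry get pos == len(inds); clamp so indexing stays valid
pos = pos.clamp(max=len(inds) - 1)

# A hit only counts if the value at that position actually matches
found = sorted_inds[pos] == input
out = torch.where(found, order[pos], input)

print(out)  # tensor([1, 1, 1, 3, 3, 0, 2, 0, 2])
```

The `found` mask is what keeps unmatched values untouched. Dropping it would silently map them to a neighboring index.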