amaleki
(Amir Maleki)
February 3, 2021, 4:25am
Consider
a = torch.tensor([0, 1, 2, 3, 4], device='cuda')
b = torch.tensor([4, 5, 6, 7, 8], device='cuda')
c = torch.tensor([0, 0, 0, 1, 1, 3, 4], device='cuda')
How can I get d = torch.tensor([4, 4, 4, 5, 5, 7, 8], device='cuda')
without converting to a regular list or offloading to the CPU?
(If they were regular lists, something as simple as
mapping = dict(zip(a, b))
d = list(map(mapping.get, c))
would have worked.)
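For reference, here is the list-based lookup as a self-contained script. map() is lazy in Python 3, so it is wrapped in list() to get the actual values. Doing the same with the tensors would first require .tolist(), which copies the data to the CPU, which is exactly what the question wants to avoid.

```python
# Plain-Python version of the dictionary lookup (lists, no tensors).
a = [0, 1, 2, 3, 4]
b = [4, 5, 6, 7, 8]
c = [0, 0, 0, 1, 1, 3, 4]

mapping = dict(zip(a, b))       # {0: 4, 1: 5, 2: 6, 3: 7, 4: 8}
d = list(map(mapping.get, c))   # look up every element of c
print(d)  # [4, 4, 4, 5, 5, 7, 8]
```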
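torch.gather should work fine here, since the values in a are exactly the positions 0 to 4, so c can be used directly as an index into b: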
import torch
b = torch.tensor([4, 5, 6, 7, 8], device='cuda')
c = torch.tensor([0, 0, 0, 1, 1, 3, 4], device='cuda')
res = torch.gather(b, 0, c)
print(res)
> tensor([4, 4, 4, 5, 5, 7, 8], device='cuda:0')
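For 1-D tensors, gathering along dim 0 picks the same elements as plain advanced indexing with an integer tensor, so b[c] is an equivalent spelling. The sketch below checks this; the CPU fallback for `device` is an addition so it also runs without a GPU.

```python
import torch

# Fall back to the CPU when no GPU is available (assumption for portability).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
b = torch.tensor([4, 5, 6, 7, 8], device=device)
c = torch.tensor([0, 0, 0, 1, 1, 3, 4], device=device)

res_gather = torch.gather(b, 0, c)  # gather along dim 0
res_index = b[c]                    # advanced indexing with an integer tensor
print(torch.equal(res_gather, res_index))  # True
print(res_index.tolist())                  # [4, 4, 4, 5, 5, 7, 8]
```

Both stay on the tensors' device; only the final print/tolist calls copy results to the host.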
amaleki
(Amir Maleki)
February 3, 2021, 6:06am
What if a is something else, like a = torch.tensor([20, 10, 30, 40, 0])?
I don’t know how a relates to the other two tensors. What would these values of a mean?
amaleki
(Amir Maleki)
February 3, 2021, 7:27am
Basically, I want to replace the elements of c using a dictionary whose keys are the elements of a and whose values are the elements of b.
For example,
a=torch.tensor([1, 0, 3, -1, 4], device='cuda')
b=torch.tensor([4, 5, 6, 7, 8], device='cuda')
c=torch.tensor([0, 0, 0, 1, -1, 3, 4], device='cuda')
then
d = custom_map(a, b, c)  # expected: torch.tensor([5, 5, 5, 4, 7, 6, 8], device='cuda')
Is this clear enough?
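One way to implement the custom_map described here without leaving the GPU is to sort the keys and binary-search them with torch.searchsorted. This is a sketch, not a solution given in the thread, and it assumes that a has unique entries, that a and c share a dtype, and that every element of c appears in a. The name custom_map comes from the post; the body is illustrative.

```python
import torch

def custom_map(keys, values, x):
    """Replace each element of x with values[i], where keys[i] equals that element.

    Sketch assuming: keys is 1-D with unique entries, values has the same
    length, and every element of x appears in keys.
    """
    # Sort the keys so they can be binary-searched; reorder values to match.
    sorted_keys, order = torch.sort(keys)
    sorted_values = values[order]
    # Position of each element of x within sorted_keys (stays on x's device).
    pos = torch.searchsorted(sorted_keys, x)
    # Elements larger than every key get pos == len(keys); clamp them so the
    # check below can index safely.
    pos = pos.clamp(max=sorted_keys.numel() - 1)
    # Optional sanity check: an element missing from keys lands on a different
    # key. torch.equal returns a Python bool, so on CUDA this forces a sync.
    if not torch.equal(sorted_keys[pos], x):
        raise KeyError("x contains values that are not in keys")
    return sorted_values[pos]


device = 'cuda' if torch.cuda.is_available() else 'cpu'
a = torch.tensor([1, 0, 3, -1, 4], device=device)
b = torch.tensor([4, 5, 6, 7, 8], device=device)
c = torch.tensor([0, 0, 0, 1, -1, 3, 4], device=device)

d = custom_map(a, b, c)
print(d.tolist())  # [5, 5, 5, 4, 7, 6, 8]
```

Sorting costs O(n log n) in the number of keys, and the lookup is one binary search per element of c, so it works for arbitrary (including negative or sparse) integer keys such as [20, 10, 30, 40, 0].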
You could index a with c, but you would need to remove the negative values from a, as gather expects non-negative indices:
import torch
device = 'cuda'
a = torch.tensor([1, 0, 3, 2, 4], device=device)
b = torch.tensor([4, 5, 6, 7, 8], device=device)
c = torch.tensor([0, 0, 0, 1, -1, 3, 4], device=device)
res = torch.gather(b, 0, a[c])
print(res)
> tensor([5, 5, 5, 4, 8, 6, 8], device='cuda:0')
I’m also unsure about your expected result tensor, as the -1 in c would index a at the last position, and b[4] would be 8, not 7.
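The point about -1 is easier to see by printing the intermediate index tensor. Advanced indexing (a[c]) wraps negative indices the way Python lists do, so the -1 in c selects a[4] == 4, and only non-negative indices reach gather. The CPU fallback is added so the snippet runs anywhere.

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
a = torch.tensor([1, 0, 3, 2, 4], device=device)
b = torch.tensor([4, 5, 6, 7, 8], device=device)
c = torch.tensor([0, 0, 0, 1, -1, 3, 4], device=device)

idx = a[c]  # -1 wraps around to the last element of a, which is 4
print(idx.tolist())  # [1, 1, 1, 0, 4, 2, 4]
res = torch.gather(b, 0, idx)  # b[4] == 8 at the position of the -1
print(res.tolist())  # [5, 5, 5, 4, 8, 6, 8]
```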
amaleki
(Amir Maleki)
February 3, 2021, 5:43pm
a is not a list of indices. a holds the old values and b the new values; think of them as the keys and values of a mapping dictionary.
For example, every 0 in c (0 is the second element of a) should be mapped to 5 (the second element of b).
If a, b, and c were regular lists:
my_dict = dict(zip(a, b))
d = list(map(my_dict.get, c))
Sorry if my explanation was not clear enough.
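When the keys are integers spanning a modest range, another GPU-friendly option is a dense lookup table: allocate one slot per possible key, scatter the values into it, then index the table with c. The helper name map_with_table and its fill parameter are hypothetical, and the sketch assumes every element of c lies between min(a) and max(a); elements in that range that are not keys come back as fill.

```python
import torch

def map_with_table(keys, values, x, fill=-1):
    """Hypothetical helper: dense lookup-table version of dict(zip(keys, values)).

    Assumes integer keys with a modest range (the table has
    max(keys) - min(keys) + 1 slots) and min(keys) <= x <= max(keys).
    """
    lo = int(keys.min())  # int() copies a scalar to the host (small sync on CUDA)
    hi = int(keys.max())
    # One slot per possible key, pre-filled with the "missing" marker.
    table = torch.full((hi - lo + 1,), fill, dtype=values.dtype, device=values.device)
    # Shift keys so the smallest one lands at slot 0 (this handles negative keys).
    table[keys - lo] = values
    # Look up every element of x in a single indexing operation.
    return table[x - lo]


device = 'cuda' if torch.cuda.is_available() else 'cpu'
a = torch.tensor([1, 0, 3, -1, 4], device=device)
b = torch.tensor([4, 5, 6, 7, 8], device=device)
c = torch.tensor([0, 0, 0, 1, -1, 3, 4], device=device)

d = map_with_table(a, b, c)
print(d.tolist())  # [5, 5, 5, 4, 7, 6, 8]
```

The lookup is a single gather-style index, but memory grows with the key range rather than the number of keys, so for sparse or widely spread keys a sort-and-search approach scales better.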