Mapping tensors without offloading to the CPU

consider

a=torch.tensor([0, 1, 2, 3, 4], device='cuda')
b=torch.tensor([4, 5, 6, 7, 8], device='cuda')
c=torch.tensor([0, 0, 0, 1, 1, 3, 4], device='cuda')

How can I get d=torch.tensor([4, 4, 4, 5, 5, 7, 8], device='cuda')
without converting to a regular list or offloading to the CPU?

(If they were regular lists, something as simple as

mapping = dict(zip(a, b))
d = list(map(mapping.get, c))

would have worked.)

torch.gather should work fine:

b = torch.tensor([4, 5, 6, 7, 8], device='cuda')
c = torch.tensor([0, 0, 0, 1, 1, 3, 4], device='cuda')

res = torch.gather(b, 0, c)
print(res)
> tensor([4, 4, 4, 5, 5, 7, 8], device='cuda:0')
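For a 1-D tensor, plain advanced indexing gives the same result as gather along dim 0, which can read a bit more naturally here (shown on the default device; add device='cuda' if available):

```python
import torch

b = torch.tensor([4, 5, 6, 7, 8])
c = torch.tensor([0, 0, 0, 1, 1, 3, 4])

# Advanced indexing b[c] is equivalent to torch.gather(b, 0, c) for 1-D tensors
res = b[c]
print(res)  # tensor([4, 4, 4, 5, 5, 7, 8])
```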

What if a is something else, like a=torch.tensor([20, 10, 30, 40, 0])?

I don’t know how a relates to the other two tensors. What would these values of a mean?

Basically, I want to replace the elements of c using a dictionary whose keys are a and whose values are b.
For example,

a=torch.tensor([1, 0, 3, -1, 4], device='cuda')
b=torch.tensor([4, 5, 6, 7, 8], device='cuda')
c=torch.tensor([0, 0, 0, 1, -1, 3, 4], device='cuda')

then

d = custom_map(a,b,c) # = torch.tensor([5, 5, 5, 4, 7, 6, 8], device='cuda')

Is this clear enough?

You could index a with c, but you would need to remove the negative values from a, as gather expects non-negative indices:

device = 'cuda'
a = torch.tensor([1, 0, 3, 2, 4], device=device)
b = torch.tensor([4, 5, 6, 7, 8], device=device)
c = torch.tensor([0, 0, 0, 1, -1, 3, 4], device=device)

res = torch.gather(b, 0, a[c])
print(res)
> tensor([5, 5, 5, 4, 8, 6, 8], device='cuda:0')

I’m also unsure about your expected result tensor: the -1 in c would index a at the last position, and b[4] would be 8, not 7.

a is not a list of indices. a holds the old values and b the new values; think of them as the keys and values of a mapping dictionary.
For example, every 0 in c (the second element of a) should be mapped to 5 (the second element of b).
If a, b, and c were regular lists:

my_dict = dict(zip(a, b))
d = list(map(my_dict.get, c))

Sorry if my explanation was not clear enough.
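One way to do this kind of key-to-value lookup entirely on the GPU is to sort the keys and use torch.searchsorted. This is a sketch (custom_map is a hypothetical helper name, not a torch function), and it assumes every element of c actually occurs in a; shown on the default device, but it works the same with device='cuda':

```python
import torch

def custom_map(a, b, c):
    # Treat a as dictionary keys and b as the corresponding values,
    # then look each element of c up in that mapping, all on-device.
    # Assumes every element of c occurs in a.
    sort_idx = torch.argsort(a)
    a_sorted = a[sort_idx]                 # keys in ascending order
    b_sorted = b[sort_idx]                 # values permuted the same way
    pos = torch.searchsorted(a_sorted, c)  # position of each element of c among the keys
    return b_sorted[pos]

a = torch.tensor([1, 0, 3, -1, 4])  # keys
b = torch.tensor([4, 5, 6, 7, 8])   # values
c = torch.tensor([0, 0, 0, 1, -1, 3, 4])
d = custom_map(a, b, c)
print(d)  # tensor([5, 5, 5, 4, 7, 6, 8])
```

The sort costs O(n log n) once, and each lookup is then a binary search, so no dict, no Python loop, and no transfer to the CPU.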