Access a tensor entry programmatically

Hi, I may be missing something, but assume you have a tensor v of shape, say, (a, b, c, d), and another tensor, say ind = torch.Tensor([3, 2, 1, 0]).

How can I access v[3, 2, 1, 0] programmatically, i.e., using only the variables v and ind?

Edit: I know that v[tuple(ind)] works, but I want this to be scriptable with torch.jit, and that solution apparently is not.

You can use this as your index:

ind.long().split(1)
# Example
import torch

a, b, c, d = 10, 20, 30, 40

v = torch.rand(a, b, c, d)
ind = torch.Tensor([3, 2, 1, 0])  # float tensor; .long() below makes it usable as an index

print(torch.all(v[ind.long().split(1)] == v[3, 2, 1, 0]))

# Output:
tensor(True)
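Each piece produced by split(1) has shape (1,), so this is advanced indexing: the result is a one-element tensor of shape (1,) (a copy, not a view) rather than the 0-d tensor that v[3, 2, 1, 0] returns. The comparison above still prints tensor(True) because of broadcasting. A small check, using .squeeze(0) to recover a 0-d value (or .item() if a Python number is wanted):

```python
import torch

v = torch.rand(10, 20, 30, 40)
ind = torch.Tensor([3, 2, 1, 0])

# Tuple of four shape-(1,) long tensors -> advanced indexing, result shape (1,)
picked = v[ind.long().split(1)]
print(picked.shape)  # torch.Size([1])

# Drop the length-1 dimension to get the same 0-d tensor as v[3, 2, 1, 0]
scalar = picked.squeeze(0)
print(torch.equal(scalar, v[3, 2, 1, 0]))  # True
```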

Awesome, it works, but unfortunately not inside a torch.jit.script function:

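import torch

@torch.jit.script  # compiles immediately, when the function is defined
def access(v, index):
    # Subscripting with index.split(1) (a List[Tensor] in TorchScript)
    # is what fails to compile here
    return v[index.split(1)]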

v = torch.randn((3, 5, 5))
index = torch.tensor([0,2,2])

print(index)
print(v[0, 2, 2])
print(v[index.split(1)])
print(access(v, index))

Any idea?

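If you use the @torch.jit.export decorator instead, it runs. Note, however, that on a free function @torch.jit.export only marks the function and returns it unchanged, so the code still executes as ordinary eager Python; the decorator is meant for methods of an nn.Module that should be compiled alongside forward when the module is scripted.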

import torch

@torch.jit.export
def access(v, index):
    return v[index.split(1)]

v = torch.randn((3, 5, 5))
index = torch.tensor([0,2,2])

print(index)
print(v[0, 2, 2])
print(v[index.split(1)])
print(access(v, index))

Here is the documentation, but I do not have much experience with the JIT, so this might not be what you are looking for.

Thanks a lot. Unfortunately, this will be called from forward. I will investigate.

I got it to work with @torch.jit.script by calling the index method directly:

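import torch

@torch.jit.script
def access(v, index: torch.Tensor):
    # v.index(...) calls the aten::index operator with the list of index
    # tensors directly, instead of going through the subscript syntax
    return v.index(index.split(1))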
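Since the accessor is needed inside forward, here is a minimal sketch of the same v.index(...) call inside a scripted nn.Module. The module name Lookup is hypothetical, and the sketch assumes, as in the thread, that TorchScript resolves v.index to the aten::index operator. The .long() cast is an addition so that a float index tensor, like the one in the original question, also works.

```python
import torch


class Lookup(torch.nn.Module):
    # Hypothetical module: forward returns the entry of v addressed by index.
    def forward(self, v: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # split(1) yields shape-(1,) index tensors, one per leading dimension,
        # so the result has shape (1,) rather than being 0-d.
        return v.index(index.long().split(1))


# Scripting the module compiles forward with TorchScript.
scripted = torch.jit.script(Lookup())

v = torch.randn(3, 5, 5)
index = torch.tensor([0, 2, 2])
print(scripted(v, index))
print(v[0, 2, 2])
```

Outside TorchScript, v[index.long().split(1)] is the safer spelling, since v.index is resolved by the TorchScript compiler and should not be relied on as a regular eager-mode method.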

Awesome, thanks, it indeed works fine!
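Another option, not proposed in the thread, is to avoid indexing with a list of tensors altogether and index one dimension at a time with plain Python ints, which TorchScript supports. This is a sketch assuming ind has at most v.dim() entries; the name access_loop is hypothetical. Unlike the split(1) approach, it returns a 0-d tensor (a view into v), exactly like v[3, 2, 1, 0].

```python
import torch


@torch.jit.script
def access_loop(v: torch.Tensor, ind: torch.Tensor) -> torch.Tensor:
    # Convert once so float index tensors are accepted too.
    idx = ind.long()
    out = v
    # Peel off one leading dimension per entry; int(...) turns each 0-d
    # entry into a Python int, which TorchScript accepts as a normal index.
    for k in range(idx.size(0)):
        out = out[int(idx[k])]
    return out


v = torch.rand(10, 20, 30, 40)
ind = torch.Tensor([3, 2, 1, 0])
print(torch.equal(access_loop(v, ind), v[3, 2, 1, 0]))  # True
```

The trade-off is that each int(...) reads a value back to the host, which forces a synchronization when the tensors live on a GPU; the v.index(...) version keeps everything as tensor operations.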
