aliutkus
(Antoine Liutkus)
Hi, I may be missing something, but assume you have a tensor v of shape, say, (a, b, c, d), and another tensor, say ind = torch.Tensor([3, 2, 1, 0]).
How can I access v[3, 2, 1, 0] programmatically, i.e. using only the variables v and ind?
Edit: I know that v[tuple(ind.long())] works, but I want this to be scriptable with torch.jit.script, and that solution apparently is not.
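For comparison, here is a minimal eager-mode sketch of the tuple approach from the edit (the sizes are arbitrary). The cast to long matters: a tensor built with torch.Tensor is float, and float tensors cannot be used as indices.

```python
import torch

a, b, c, d = 10, 20, 30, 40  # arbitrary sizes for illustration
v = torch.rand(a, b, c, d)
ind = torch.Tensor([3, 2, 1, 0])  # float tensor, as in the question

# tuple() turns the 1-D index tensor into four 0-d long tensors,
# one per dimension of v
element = v[tuple(ind.long())]

print(element.shape)  # torch.Size([]): a 0-d tensor, like v[3, 2, 1, 0]
print(bool(element == v[3, 2, 1, 0]))  # True
```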
You can use ind.long().split(1) as your index. It turns ind into a tuple of 1-element long tensors, one per dimension of v:
# Example
import torch

a, b, c, d = 10, 20, 30, 40
v = torch.rand(a, b, c, d)
ind = torch.Tensor([3, 2, 1, 0])
# Index with one 1-element long tensor per dimension
print(torch.all(v[ind.long().split(1)] == v[3, 2, 1, 0]))
# Output:
tensor(True)
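One detail worth knowing: because split(1) produces 1-element tensors rather than 0-d ones, the result has shape (1,) instead of being a scalar, which is why torch.all is used in the comparison. The same idea extends to reading several elements at once. The sketch below is an illustration, not part of the original answer: it assumes a hypothetical coords matrix whose columns are coordinates, and uses unbind(0) to get one index tensor per dimension.

```python
import torch

v = torch.rand(10, 20, 30, 40)
ind = torch.tensor([3, 2, 1, 0])

# split(1) yields four tensors of shape (1,), so advanced indexing
# returns a 1-element tensor instead of a 0-d scalar
single = v[ind.split(1)]
print(single.shape)  # torch.Size([1])

# Hypothetical batch: each column is one (i, j, k, l) coordinate
coords = torch.tensor([[3, 0],
                       [2, 5],
                       [1, 7],
                       [0, 9]])
# unbind(0) gives one 1-D tensor per dimension; indexing pairs them up
batch = v[coords.unbind(0)]
print(batch.shape)  # torch.Size([2])
print(torch.equal(batch, torch.stack([v[3, 2, 1, 0], v[0, 5, 7, 9]])))  # True
```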
aliutkus
(Antoine Liutkus)
Awesome. It works, but unfortunately not inside a torch.jit.script function:
import torch

@torch.jit.script
def access(v, index):
    return v[index.split(1)]

v = torch.randn((3, 5, 5))
index = torch.tensor([0, 2, 2])
print(index)
print(v[0, 2, 2])
print(v[index.split(1)])
print(access(v, index))
Any idea?
If you use the @torch.jit.export decorator instead, it works:
import torch

@torch.jit.export
def access(v, index):
    return v[index.split(1)]

v = torch.randn((3, 5, 5))
index = torch.tensor([0, 2, 2])
print(index)
print(v[0, 2, 2])
print(v[index.split(1)])
print(access(v, index))
Here is the documentation. But I do not have much experience with JIT, so this might not be what you are looking for.
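A quick way to check what the export decorator actually does: torch.jit.export is intended to mark methods of an nn.Module for compilation, and on a standalone function it appears to only attach a marker and return the function unchanged. If so, access above simply runs as ordinary Python, which would explain both why it works and why it does not help inside a scripted forward. The helper add_one is hypothetical and only serves as a compiled counterpart for the comparison.

```python
import types
import torch

@torch.jit.export
def access(v, index):
    return v[index.split(1)]

@torch.jit.script
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

# export only tags the function; it stays a plain Python function
print(isinstance(access, types.FunctionType))  # True
print(isinstance(access, torch.jit.ScriptFunction))  # False
# script really compiles, producing a ScriptFunction
print(isinstance(add_one, torch.jit.ScriptFunction))  # True
```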
aliutkus
(Antoine Liutkus)
Thanks a lot. Unfortunately, this will be called from forward, so I will investigate further.
I got a bit further with @torch.jit.script by calling the index method on the tensor directly instead of using the [] syntax:
import torch

@torch.jit.script
def access(v, index: torch.Tensor):
    return v.index(index.split(1))
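Since the access ends up in forward, here is a sketch of a different technique (not the one from this thread) that avoids advanced indexing altogether: compute the row-major offset of the coordinate and index the flattened tensor with a single integer. The module name Accessor is hypothetical, and the sketch assumes index holds one non-negative coordinate per dimension of v.

```python
import torch

class Accessor(torch.nn.Module):
    """Hypothetical module that reads one element of v at a coordinate tensor."""

    def forward(self, v: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # Row-major (C-order) offset, Horner style: ((i0 * s1 + i1) * s2 + i2) ...
        offset = 0
        for dim in range(v.dim()):
            offset = offset * v.size(dim) + int(index[dim].item())
        # reshape(-1) lists the elements in the same row-major order
        return v.reshape(-1)[offset]

scripted = torch.jit.script(Accessor())

v = torch.randn(3, 5, 5)
index = torch.tensor([0, 2, 2])
out = scripted(v, index)
print(out.shape)  # torch.Size([])
print(bool(out == v[0, 2, 2]))  # True
```

Design note: .item() copies each coordinate to the host, which forces a synchronization when the tensors live on a GPU, so the index-method version above is likely preferable there. Negative coordinates are not handled by this sketch.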
aliutkus
(Antoine Liutkus)
Awesome, thanks, it indeed works fine!