Slicing with a tensor index doesn't work as expected in TorchScript

To reproduce, compare the eager-mode result with the scripted one.

Expected behaviour (eager mode):

import torch

def func(obj_x, obj_ind):
    new_x = torch.zeros((20, 3, 7), device=obj_x.device, dtype=torch.float32)
    j = torch.zeros((1,), device=obj_x.device, dtype=torch.long)
    for i, idx in enumerate(obj_ind[1:], start=1):
        new_x[idx][j] = obj_x[i]
    return new_x

obj_x = torch.randn(10, 7)
obj_ind = torch.tensor(list(range(0, 20, 2)))
func(obj_x, obj_ind)
>> tensor([[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.2289, -1.1137,  0.2680,  0.0294,  0.3793, -0.1392, -0.7233],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0367,  0.5694, -0.1834,  0.9413,  1.1738, -1.4721, -2.0817],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
...

Using TorchScript, the assigned rows stay zero:

@torch.jit.script
def func(obj_x, obj_ind):
    new_x = torch.zeros((20, 3, 7), device=obj_x.device, dtype=torch.float32)
    j = torch.zeros((1,), device=obj_x.device, dtype=torch.long)
    for i, idx in enumerate(obj_ind[1:], start=1):
        new_x[idx][j] = obj_x[i]
    return new_x

func(obj_x, obj_ind)
tensor([[[0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.]],
...

Did I miss something?

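I think you are hitting multiple issues.
The first one seems to be the “double indexing”, which doesn’t assign the values to the original tensor: in the scripted function, new_x[idx] with a tensor index seems to create a new tensor, so the following [j] = ... assignment only modifies that temporary.
Index new_x in a single operation and the assignment should work: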

new_x[idx, j] = obj_x[i]
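The copy-versus-in-place difference can also be reproduced in eager mode with a 1-element index tensor, which always goes through advanced indexing. This is a minimal sketch with made-up shapes (the names `x` and `row` are only for illustration), and it assumes the scripted `new_x[idx]` behaves like this advanced-indexing case:

```python
import torch

x = torch.zeros(4, 3)
row = torch.tensor([1])  # 1-element index tensor, so indexing is "advanced"

# Two-step ("double") indexing: x[row] returns a new tensor (a copy),
# so the second assignment only modifies that temporary copy.
x[row][0] = 5.0
print(x.sum())  # x itself is still all zeros

# Single indexing operation: performed in place on x (index_put_).
x[row, 0] = 5.0
print(x[1, 0])  # x[1, 0] now holds 5.0
```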

However, even though new_x now contains values, enumerate(..., start=idx) does not seem to work correctly in a scripted function: the start value is ignored.

import torch

def fun():
    for i, idx in enumerate(torch.arange(5), start=1):
        print(i, idx)
fun()

@torch.jit.script
def fun():
    for i, idx in enumerate(torch.arange(5), start=1):
        print(i, idx)
fun()

Output (eager function first, then the scripted one):

1 tensor(0)
2 tensor(1)
3 tensor(2)
4 tensor(3)
5 tensor(4)

0 0
[ CPULongType{} ]
1 1
[ CPULongType{} ]
2 2
[ CPULongType{} ]
3 3
[ CPULongType{} ]
4 4
[ CPULongType{} ]
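Until the start value is handled correctly, one possible workaround is to drop start and add the offset by hand. The sketch below uses a hypothetical helper, `shifted_indices`, to show the pattern, then applies it together with the single-indexing fix to the original func. It has not been checked against every PyTorch release; it only assumes that plain enumerate over a tensor and tensor-indexed assignment are scriptable, as the examples in this thread suggest.

```python
import torch
from typing import List


@torch.jit.script
def shifted_indices(values: torch.Tensor, start: int) -> List[int]:
    # Hypothetical helper: enumerate() without `start`, offset added manually.
    out: List[int] = []
    for k, value in enumerate(values):
        out.append(k + start)
    return out


@torch.jit.script
def func(obj_x: torch.Tensor, obj_ind: torch.Tensor) -> torch.Tensor:
    new_x = torch.zeros((20, 3, 7), device=obj_x.device, dtype=torch.float32)
    j = torch.zeros((1,), device=obj_x.device, dtype=torch.long)
    for k, idx in enumerate(obj_ind[1:]):
        i = k + 1  # the offset enumerate(..., start=1) was supposed to add
        # One indexing operation, so the write lands in new_x itself
        new_x[idx, j] = obj_x[i]
    return new_x


obj_x = torch.randn(10, 7)
obj_ind = torch.arange(0, 20, 2)
print(shifted_indices(torch.arange(5), 1))
print(func(obj_x, obj_ind)[2, 0])  # should match obj_x[1]
```

The manual offset keeps the result independent of how start is treated, so the same code should also behave correctly once the bug is fixed.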

Would you mind creating an issue for this on GitHub?

Thanks for the response, @ptrblck.

At least now I know how to tackle the first issue. I also created an issue for the bug you found: “enumerate(..., start=idx) is not working correctly while scripting” (Issue #67142, pytorch/pytorch on GitHub).

Feel free to continue our discussion on GitHub.