To reproduce:
Expected behaviour:
def func(obj_x, obj_ind):
new_x = torch.zeros((20, 3, 7), device=obj_x.device, dtype=torch.float32)
j = torch.zeros((1,), device=obj_x.device, dtype=torch.long)
for i, idx in enumerate(obj_ind[1:], start=1):
new_x[idx][j] = obj_x[i]
return new_x
obj_x = torch.randn(10, 7)
obj_ind = torch.tensor(list(range(0, 20, 2)))
func(obj_x, obj_ind)
>> tensor([[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[-0.2289, -1.1137, 0.2680, 0.0294, 0.3793, -0.1392, -0.7233],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[ 0.0367, 0.5694, -0.1834, 0.9413, 1.1738, -1.4721, -2.0817],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
...
Using the torchscript:
@torch.jit.script
def func(obj_x, obj_ind):
new_x = torch.zeros((20, 3, 7), device=obj_x.device, dtype=torch.float32)
j = torch.zeros((1,), device=obj_x.device, dtype=torch.long)
for i, idx in enumerate(obj_ind[1:], start=1):
new_x[idx][j] = obj_x[i]
return new_x
func(obj_x, obj_ind)
tensor([[[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.]],
[[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.]],
...
Did I miss something?