Hello everyone
I was wondering whether this kind of indexing (with assignment) is supported in TorchScript. Consider the code below:
```python
import torch

def test(A, indx):
    A[indx.unbind(-1)] = 0
    return A

my_tensor = torch.ones(3, 3, dtype=torch.int32)
my_indx = torch.tensor([[0, 0], [1, 1], [1, 2]])

print("PYTORCH")
res_std = test(my_tensor.clone(), my_indx)
print(res_std)
print("---------")

print("TORCHSCRIPT")
scripted_test = torch.jit.script(test)
print(scripted_test.code)
res_jit = scripted_test(my_tensor.clone(), my_indx)
print(res_jit)
```
It generates the following output:
```
PYTORCH
tensor([[0, 1, 1],
        [1, 0, 0],
        [1, 1, 1]], dtype=torch.int32)
---------
TORCHSCRIPT
def test(A: Tensor,
    indx: Tensor) -> Tensor:
  _0 = torch.tensor(torch.unbind(indx, -1), dtype=4, device=None, requires_grad=False)
  _1 = torch.tensor(0, dtype=ops.prim.dtype(A), device=ops.prim.device(A), requires_grad=False)
  _2 = annotate(List[Optional[Tensor]], [_0])
  _3 = torch.index_put_(A, _2, _1, False)
  return A
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Input must be of ints, floats, or bools, got Tensor
```
First of all, the error message doesn't say much about what is going on. Am I using TorchScript in an unexpected way? Is this a bug? Is there some known limitation in TorchScript indexing? I haven't managed to find any related issue on PyTorch's GitHub or in this forum. Is there a known workaround?
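A possible workaround, sketched under the assumption that `A` is 2-D and `indx` has shape `(N, 2)`. Judging from the generated code above, TorchScript appears to feed the `List[Tensor]` returned by `torch.unbind` into `torch.tensor(...)` (the `_0` line) and use the result as a single index; `torch.tensor` only accepts (nested lists of) ints, floats, or bools, which would explain the error. Passing one index tensor per dimension explicitly avoids that conversion. `zero_at` is a hypothetical name, and this is a sketch rather than a confirmed fix:

```python
import torch

@torch.jit.script
def zero_at(A: torch.Tensor, indx: torch.Tensor) -> torch.Tensor:
    # Assumes A is 2-D and indx has shape (N, 2). Instead of the List[Tensor]
    # produced by unbind, pass one index tensor per dimension explicitly,
    # which TorchScript compiles to an index_put_ with two index tensors.
    A[indx[:, 0], indx[:, 1]] = 0
    return A

my_tensor = torch.ones(3, 3, dtype=torch.int32)
my_indx = torch.tensor([[0, 0], [1, 1], [1, 2]])
print(zero_at(my_tensor.clone(), my_indx))
```

Like the original function, this modifies `A` in place, so the caller still needs `.clone()` to keep the input intact.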
I think this happens in general when indexing a tensor this way, not only when assigning. The following function gives the same error message:
```python
def test(A, indx):
    return A[indx.unbind(-1)]
```
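For the read-only case with a 2-D `A`, `A[indx[:, 0], indx[:, 1]]` is the direct equivalent. Since `unbind(-1)` is presumably used so the function works for any number of dimensions, here is a sketch of a dimension-agnostic alternative: fold each row of `indx` into a row-major flat offset and index the flattened tensor with a single index tensor. It assumes `indx` has shape `(N, A.dim())` with non-negative, in-range indices (negative indices are not handled); `gather_at` is a hypothetical name:

```python
import torch

@torch.jit.script
def gather_at(A: torch.Tensor, indx: torch.Tensor) -> torch.Tensor:
    # indx has shape (N, A.dim()): each row holds one coordinate per dimension.
    # Fold the coordinates into row-major flat offsets (Horner's scheme).
    flat = torch.zeros([indx.size(0)], dtype=torch.long, device=indx.device)
    for d in range(A.dim()):
        flat = flat * A.size(d) + indx[:, d]
    # reshape(-1) enumerates A in row-major order, so the offsets line up
    # even when A is not contiguous (reshape copies in that case).
    return A.reshape(-1)[flat]

my_tensor = torch.arange(9, dtype=torch.int32).reshape(3, 3)
my_indx = torch.tensor([[0, 0], [1, 1], [1, 2]])
print(gather_at(my_tensor, my_indx))
```

Because `reshape(-1)` may return a copy, this trick is only safe for reading; an assignment variant would have to write through `A.view(-1)`, which generally requires a contiguous `A`.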
This code has been tested with PyTorch 1.8.1 and 1.9, with the same result.
Any suggestions or comments are very much appreciated.
Thanks in advance!