I was wondering whether this kind of indexing (plus assignment) is supported in TorchScript. Given the code below:
import torch

def test(A, indx):
    A[indx.unbind(-1)] = 0
    return A

my_tensor = torch.ones(3, 3, dtype=torch.int32)
my_indx = torch.tensor([[0, 0], [1, 1], [1, 2]])

print("PYTORCH")
res_std = test(my_tensor.clone(), my_indx)
print(res_std)
print("---------")
print("TORCHSCRIPT")
scripted_test = torch.jit.script(test)
print(scripted_test.code)
res_jit = scripted_test(my_tensor.clone(), my_indx)
print(res_jit)
It generates the following output:
PYTORCH
tensor([[0, 1, 1],
        [1, 0, 0],
        [1, 1, 1]], dtype=torch.int32)
---------
TORCHSCRIPT
def test(A: Tensor,
    indx: Tensor) -> Tensor:
  _0 = torch.tensor(torch.unbind(indx, -1), dtype=4, device=None, requires_grad=False)
  _1 = torch.tensor(0, dtype=ops.prim.dtype(A), device=ops.prim.device(A), requires_grad=False)
  _2 = annotate(List[Optional[Tensor]], [_0])
  _3 = torch.index_put_(A, _2, _1, False)
  return A

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Input must be of ints, floats, or bools, got Tensor
First of all, the error message doesn't say much about what is happening. Am I using TorchScript in an unexpected way? Is this a bug? Or is there some limitation on indexing in TorchScript? I haven't managed to find any related issue on PyTorch's GitHub or in this forum. Is there a known workaround?
I think this happens in general when indexing a tensor this way, not only on assignment. The following function gives the same error message:
def test(A, indx):
    return A[indx.unbind(-1)]
This code has been tested with PyTorch 1.8.1 and 1.9, with the same result.
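Judging from the printed TorchScript code, the tuple returned by unbind gets wrapped in torch.tensor(...) instead of being treated as one index per dimension, which appears to be where the error comes from. A possible workaround, sketched under the assumption that the tensor is always 2-D, is to write out one index tensor per dimension explicitly. The function name zero_at is hypothetical, and this sketch has not been confirmed on 1.8.1 or 1.9 specifically:

```python
import torch

def zero_at(A: torch.Tensor, indx: torch.Tensor) -> torch.Tensor:
    # Assumption: A is 2-D and indx has shape (N, 2).
    # Spell out one index tensor per dimension instead of passing
    # the tuple returned by indx.unbind(-1).
    rows = indx[:, 0]
    cols = indx[:, 1]
    A[rows, cols] = 0
    return A

scripted_zero_at = torch.jit.script(zero_at)

my_tensor = torch.ones(3, 3, dtype=torch.int32)
my_indx = torch.tensor([[0, 0], [1, 1], [1, 2]])

res_eager = zero_at(my_tensor.clone(), my_indx)
res_jit = scripted_zero_at(my_tensor.clone(), my_indx)

# Compare the eager and scripted results.
print(torch.equal(res_eager, res_jit))
```

This gives up the rank-agnostic form of the original function, so it only helps when the number of dimensions is known in advance.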
Any suggestions or comments are very much appreciated.
Thanks in advance!