TorchScript indexing with a tuple/list of tensors

Hello everyone
I was wondering if this kind of indexing (and assignment) is supported in TorchScript. Given the code below:

import torch

def test(A, indx):
    A[indx.unbind(-1)] = 0
    return A

my_tensor = torch.ones(3, 3, dtype=torch.int32)
my_indx = torch.tensor([[0,0], [1,1], [1,2]])

print("PYTORCH")
res_std = test(my_tensor.clone(), my_indx)
print(res_std)
print("---------")
print("TORCHSCRIPT")
scripted_test = torch.jit.script(test)
print(scripted_test.code)
res_jit = scripted_test(my_tensor.clone(), my_indx)
print(res_jit)

It generates the following output:

PYTORCH
tensor([[0, 1, 1],
        [1, 0, 0],
        [1, 1, 1]], dtype=torch.int32)
---------
TORCHSCRIPT
def test(A: Tensor,
    indx: Tensor) -> Tensor:
  _0 = torch.tensor(torch.unbind(indx, -1), dtype=4, device=None, requires_grad=False)
  _1 = torch.tensor(0, dtype=ops.prim.dtype(A), device=ops.prim.device(A), requires_grad=False)
  _2 = annotate(List[Optional[Tensor]], [_0])
  _3 = torch.index_put_(A, _2, _1, False)
  return A

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Input must be of ints, floats, or bools, got Tensor

First of all, the error message doesn't say much about what is actually happening. Am I using TorchScript in an unexpected way? Is this a bug, or is there some kind of limitation with TorchScript indexing? I haven't managed to find any related issue on PyTorch's GitHub or in this forum. Is there any known workaround?

I think this happens in general when indexing a tensor this way, not just on assignment. The following function gives the same error message:

def test(A, indx):
    return A[indx.unbind(-1)]
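For this read case, the same column-unpacking trick appears to script without error (again a sketch, assuming indx has exactly two columns as in the example above):

```python
import torch

def gather_workaround(A, indx):
    # Equivalent read: index with each column as a separate tensor.
    return A[indx[:, 0], indx[:, 1]]

scripted = torch.jit.script(gather_workaround)
A = torch.ones(3, 3, dtype=torch.int32)
indx = torch.tensor([[0, 0], [1, 1], [1, 2]])
print(scripted(A, indx))  # the three values gathered at those positions
```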

This code has been tested with PyTorch 1.8.1 and 1.9, with the same result.

Any suggestions or comments are very much appreciated.
Thanks in advance!