I have a tensor A of shape m*n and a long tensor B of shape m*1. Each element of B is in [0, n-1] and indicates the index of the last unmasked element in the corresponding row of A. Based on B, I want to return an m*n mask for A.
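Stated as a formula, with rows indexed by i and columns by j (both starting from 0), the mask M described above is:

$$
M_{ij} =
\begin{cases}
1 & \text{if } j \le B_{i,0} \\
0 & \text{otherwise}
\end{cases}
\qquad 0 \le i < m,\; 0 \le j < n
$$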
For example, if A has shape 4*3 and B is [[2], [1], [0], [1]], then I want to return the following mask:
[[1, 1, 1],
[1, 1, 0],
[1, 0, 0],
[1, 1, 0]]
Is there a PyTorch API that does this?
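A minimal sketch, assuming B is a long tensor of shape (m, 1) as described: this is not a dedicated PyTorch function but a composition of `torch.arange` and a broadcasting comparison. The helper name `mask_from_last_index` is hypothetical. Comparing a row of column indices of shape (1, n) against B of shape (m, 1) broadcasts to (m, n), giving 1 wherever the column index is at most the row's last unmasked index.

```python
import torch


def mask_from_last_index(b: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical helper: build an (m, n) mask from last-unmasked indices.

    b: long tensor of shape (m, 1), each value assumed to be in [0, n-1].
    """
    # Column indices 0..n-1, created on the same device as b.
    cols = torch.arange(n, device=b.device)
    # (1, n) <= (m, 1) broadcasts to (m, n): True where column <= last index.
    return (cols.unsqueeze(0) <= b).long()


B = torch.tensor([[2], [1], [0], [1]])
mask = mask_from_last_index(B, 3)
print(mask)
```

If a boolean mask is more convenient (for example, for `masked_fill`), drop the final `.long()` and use the comparison result directly.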