Hi,
Is there a better way to reverse a padded tensor along one dimension, so that only the non-padded entries of each row are reversed and the padding stays in place? Currently, I do this:
def masked_reverse(x, pad=0.):
    """Reverse the non-padded prefix of every row of a 2-D tensor.

    Each row ``b`` of ``x`` is treated as ``n_b`` real entries followed by
    padding, where ``n_b`` is the number of entries in that row that are not
    equal to ``pad``.  The first ``n_b`` entries are reversed; the padding
    positions keep their original values.

    Works for any dtype and on any device (no CPU-only helper tensors, no
    integer matmul), in O(batch * seq_len) time and memory.

    Args:
        x: 2-D tensor of shape (batch, seq_len).
        pad: value marking padding positions.

    Returns:
        A tensor with the same shape, dtype and device as ``x``.

    Note:
        Assumes right padding (padding only at the end of each row); a real
        entry equal to ``pad`` is counted as padding.
    """
    seq_len = x.size(1)
    # Number of real (non-pad) entries per row, shape (batch, 1).
    lengths = (x != pad).sum(dim=1, keepdim=True)
    # Column positions 0..seq_len-1, shape (1, seq_len); broadcasts against lengths.
    positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
    # A real position j reads from n-1-j; a padding position reads from itself.
    index = torch.where(positions < lengths, lengths - 1 - positions, positions)
    return x.gather(1, index)
# Example input with 0's as padded tokens (right padding).
# Real tokens are drawn uniformly from 1..9, so no real token equals the pad value 0.
inp = torch.randint(1, 10, (4, 10))
inp[0, 8:] = 0  # row 0: positions 8-9 are padding
inp[1, 6:] = 0  # row 1: positions 6-9 are padding
inp[2, 9:] = 0  # row 2: position 9 is padding
# row 3 is left unpadded
>>> inp
tensor([[6, 5, 4, 4, 6, 4, 8, 5, 0, 0],
        [3, 9, 7, 4, 4, 7, 0, 0, 0, 0],
        [4, 9, 8, 1, 7, 3, 8, 3, 4, 0],
        [4, 5, 7, 9, 3, 1, 1, 5, 8, 9]])
# Reversed: only the non-padded prefix of each row is reversed; the padding stays at the end
>>> masked_reverse(inp)
tensor([[5, 8, 4, 6, 4, 4, 5, 6, 0, 0],
        [7, 4, 4, 7, 9, 3, 0, 0, 0, 0],
        [4, 3, 8, 3, 7, 1, 8, 9, 4, 0],
        [9, 8, 5, 1, 1, 3, 9, 7, 5, 4]])