Reversing a tensor along a dimension while respecting a padding mask

Hi,

Is there a better way to reverse a tensor along a dimension while respecting a padding mask? Currently, I do this:

def masked_reverse(x, pad=0.):
    """Reverse the non-pad entries of each row of a 2-D tensor, keeping pads in place.

    Each row is permuted independently along dim 1. A position ``j`` holding a
    non-pad value receives ``x[b, k]``, where ``k`` is the number of non-pad
    entries strictly after ``j`` in that row. Pad positions keep their own value.
    When the non-pad entries form a contiguous prefix (right padding), this is
    exactly "reverse the real tokens, leave the padding at the end".

    Uses O(batch * seq_len) memory: one index tensor plus a ``gather``, instead of
    a (batch, seq_len, seq_len) permutation matrix. Works for any dtype and device.

    Args:
        x: 2-D tensor of shape (batch, seq_len).
        pad: Value marking padded positions (compared with ``x != pad``).

    Returns:
        Tensor with the same shape, dtype and device as ``x``.

    Raises:
        ValueError: If ``x`` is not 2-D.
    """
    if x.dim() != 2:
        raise ValueError(f"masked_reverse expects a 2-D tensor, got {x.dim()}-D")
    mask = x != pad
    # Inclusive count of non-pad entries up to and including each position.
    prefix = mask.long().cumsum(dim=1)
    total = mask.long().sum(dim=1, keepdim=True)
    # total - prefix == number of non-pad entries strictly after each position,
    # i.e. the mirrored source index for a non-pad position.
    positions = torch.arange(x.size(1), device=x.device).expand_as(x)
    src = torch.where(mask, total - prefix, positions)
    # Each output element picks exactly one input element of the same row.
    return x.gather(1, src)
# Example batch: random token ids in [1, 10), with 0 reserved as the padding id.
inp = torch.LongTensor(4, 10).random_(1, 10)
# (row, first padded column): rows 0-2 get their tails zeroed; row 3 stays unpadded.
for row, pad_start in ((0, 8), (1, 6), (2, 9)):
    inp[row, pad_start:] = 0


>>> inp
tensor([[6, 5, 4, 4, 6, 4, 8, 5, 0, 0],
        [3, 9, 7, 4, 4, 7, 0, 0, 0, 0],
        [4, 9, 8, 1, 7, 3, 8, 3, 4, 0],
        [4, 5, 7, 9, 3, 1, 1, 5, 8, 9]])

#Reversed
>>> masked_reverse(inp)
tensor([[5, 8, 4, 6, 4, 4, 5, 6, 0, 0],
        [7, 4, 4, 7, 9, 3, 0, 0, 0, 0],
        [4, 3, 8, 3, 7, 1, 8, 9, 4, 0],
        [9, 8, 5, 1, 1, 3, 9, 7, 5, 4]])

This is probably the only way to do it right now, but it is memory-inefficient.