Reorder non-zero values in a 2d matrix

I’m looking for an elegant PyTorch way to take a 2D matrix and move all non-zero values to the beginning of each row while keeping their order. My current solution is super hacky. Is there something nicer?

For example:
input:
[[1, 5, 0, 2, 0, 0],
 [0, 0, 9, 7, 0, 0]]

output:
[[1, 5, 2, 0, 0, 0],
 [9, 7, 0, 0, 0, 0]]
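To pin down the intended behavior, here is a plain-Python reference that processes one row at a time. It is only a sketch for checking the vectorized versions against: the function name `compact_rows_reference` and its `pad_value` parameter are hypothetical, and it is far too slow for large tensors.

```python
def compact_rows_reference(rows, pad_value=0):
    """Hypothetical helper: move non-pad values to the front of each row, keeping their order."""
    out = []
    for row in rows:
        # keep the values that differ from the padding value, in their original order
        kept = [v for v in row if v != pad_value]
        # refill the rest of the row with the padding value
        out.append(kept + [pad_value] * (len(row) - len(kept)))
    return out


print(compact_rows_reference([[1, 5, 0, 2, 0, 0], [0, 0, 9, 7, 0, 0]]))
```

This prints `[[1, 5, 2, 0, 0, 0], [9, 7, 0, 0, 0, 0]]`, matching the expected output above.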

My solution

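import torch
# sort a 0/1 "is nonzero" mask in descending order; the returned indices put nonzero positions first
index = torch.sort((x != 0) * 1, dim=1, descending=True)[1]
y = x.gather(1, index)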

Nice one! Thanks for that.

Actually, there’s an issue: the order of identical items after sorting turns out to be non-deterministic (probably on GPU only). For example, on the row (part of a 2-D matrix)

[1023, 8, 1023, 1023, 1023, 1023, 1023, 1023, 1023, 1023, 1023, 1023,
606, 1023, 1023, 1023, 1023, 1023, 408, 1023, 1023, 1023, 427, 1023]

I got the result

[ 408, 427, 8, 606, 1023, 1023, 1023, 1023, 1023, 1023, 1023, 1023,
1023, 1023, 1023, 1023, 1023, 1023, 1023, 1023, 1023, 1023, 1023, 1023]

(1023 acts as 0 here).
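One way to get a deterministic order while keeping the sort/gather idea is the `stable=True` flag of `torch.sort`, which guarantees that equal keys keep their original relative order. The sketch below is an assumption-laden variant, not the original poster's code: the name `compact_rows_stable` and the `pad_value` parameter are hypothetical, and instead of sorting the "is nonzero" mask descending it sorts the "is padding" mask ascending, so only the plain stable ascending case is relied on.

```python
import torch


def compact_rows_stable(x: torch.Tensor, pad_value=0) -> torch.Tensor:
    """Hypothetical helper: stable variant of the sort/gather approach."""
    # 0 where the entry is a value to keep, 1 where it is padding
    is_pad = (x == pad_value).to(torch.int64)
    # a stable ascending sort puts all 0-keys (kept values) first and
    # preserves their original left-to-right order within each row
    index = torch.sort(is_pad, dim=1, stable=True).indices
    # reorder each row of x according to the sorted positions
    return x.gather(1, index)


x = torch.tensor([[1, 5, 0, 2, 0, 0], [0, 0, 9, 7, 0, 0]])
print(compact_rows_stable(x))
```

For the 1023-padded row above, `compact_rows_stable(row, pad_value=1023)` should keep `8, 606, 408, 427` in that order at the front.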

Maybe something like this, which avoids sorting altogether:

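import torch

# 1. nonzero mask: True where x holds a value we want to keep
s_mask = x != 0

# 2. head mask: True for the first k columns of each row, where k = number of nonzeros in that row
d_mask = torch.sum(s_mask, dim=1, keepdim=True) > torch.arange(x.shape[1], device=x.device)

# 3. copy: boolean indexing reads and writes in row-major order, so each row's
#    nonzeros fill that row's leading slots in their original order
y = torch.zeros_like(x)
y[d_mask] = x[s_mask]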
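The masking idea can also be wrapped into a function that handles a padding value other than 0, such as the 1023 from the earlier post. This is a sketch under that assumption: `compact_rows_masked` and `pad_value` are hypothetical names, and the output is pre-filled with `torch.full_like` instead of `torch.zeros_like` so the tail of each row holds the padding value.

```python
import torch


def compact_rows_masked(x: torch.Tensor, pad_value=0) -> torch.Tensor:
    """Hypothetical helper: sort-free compaction using two boolean masks."""
    # positions holding values we want to keep
    keep = x != pad_value
    # number of kept values per row, shape (rows, 1)
    counts = keep.sum(dim=1, keepdim=True)
    # (cols,) < (rows, 1) broadcasts to (rows, cols): True for the first counts[r] columns of row r
    head = torch.arange(x.shape[1], device=x.device) < counts
    # start from a tensor filled with the padding value
    y = torch.full_like(x, pad_value)
    # both masks are traversed in row-major order and have the same number of
    # True entries per row, so each row's kept values land in its leading slots
    y[head] = x[keep]
    return y


x = torch.tensor([[1, 5, 0, 2, 0, 0], [0, 0, 9, 7, 0, 0]])
print(compact_rows_masked(x))
```

Because no sort is involved, the result does not depend on sort stability on either CPU or GPU.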