Specify axis for one_hot encoding

If this assessment of the problem is correct, you could also do the following [1]. For each position in the last two dimensions, it picks the first nonzero entry along dimension 0:

def first_nonzero(x, axis=0):
    """Locate the first nonzero element of ``x`` along ``axis``.

    Args:
        x: Tensor to search. Expected to be a ``torch.Tensor``: the result
            of ``.max(axis)`` is returned as a (values, indices) pair.
        axis: Dimension to search along.

    Returns:
        The ``(values, indices)`` pair from ``Tensor.max(axis)`` applied to
        a boolean mask that is True only at the first nonzero position of
        each slice. ``indices`` holds that position. For a slice with no
        nonzero entries, ``values`` is False and ``indices`` is 0, because
        torch returns the first maximal value on ties.
    """
    # Use != 0 rather than > 0 so negative entries also count as nonzero.
    nonz = x != 0
    # The cumulative count reaches 1 at the first nonzero and stays 1 across
    # any zeros after it; AND-ing with the mask keeps only that first position.
    return ((nonz.cumsum(axis) == 1) & nonz).max(axis)

import torch

A = torch.tensor(
    [[[0, 1],
      [0, 1]],

     [[0, 0],
      [2, 2]],

     [[0, 0],
      [1, 0]]]
)

# For each (row, col) position: index along dim 0 of the first nonzero
# entry (0 when the whole column along dim 0 is zero).
_, inds = first_nonzero(A)

# Select A[inds[i, j], i, j] for every (i, j). gather works for any trailing
# shape, unlike hand-written index grids that only fit a 2x2 slice.
T = A.gather(0, inds.unsqueeze(0)).squeeze(0)
print(T)
# Output
# tensor([[0, 1],
#         [2, 1]])

[1] The `first_nonzero` function was taken from here (the original link was not preserved).

1 Like