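If this assessment of the problem is correct, you could also do the following[1]: for every (row, column) position, use `first_nonzero` to find the index along dimension 0 of the first nonzero entry, then pick out the element at that index.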
def first_nonzero(x, axis=0):
    """Locate the first nonzero entry of ``x`` along ``axis``.

    Args:
        x: A torch tensor.
        axis: Dimension to scan. Defaults to 0.

    Returns:
        The ``(values, indices)`` named tuple produced by ``Tensor.max(axis)``
        on a boolean mask. ``indices`` holds the position of the first nonzero
        entry along ``axis``. ``values`` is ``True`` where such an entry exists
        and ``False`` where the whole slice is zero; in that case the matching
        index does not point at a nonzero element.
    """
    # Use != 0 rather than > 0 so that negative entries also count as nonzero.
    nonz = x != 0
    # cumsum == 1 holds from the first nonzero up to (not including) the second.
    # AND-ing with nonz keeps only the first nonzero itself, so each slice has at
    # most one True and max() reports that unique position.
    return ((nonz.cumsum(axis) == 1) & nonz).max(axis)
A = torch.tensor(
    [[[0, 1],
      [0, 1]],
     [[0, 0],
      [2, 2]],
     [[0, 0],
      [1, 0]]]
)
# inds[i, j] is the dim-0 position that first_nonzero reports for A[:, i, j].
_, inds = first_nonzero(A)
# Pick A[inds[i, j], i, j] for every (i, j). gather handles any trailing shape,
# unlike hand-written index grids that only fit a 2x2 slice.
T = A.gather(0, inds.unsqueeze(0)).squeeze(0)
print(T)
# Output
tensor([[0, 1],
[2, 1]])
[1] Function taken from here.