Check that each row of a torch tensor has exactly one element equal to one

I would appreciate it if somebody could help me check whether a tensor is a valid indicator function,
meaning that in each row every element is zero except for exactly one, which is 1.
So A is valid, while B and C are not:

A = torch.tensor([[1, 0, 0], [0, 1, 0]])
B = torch.tensor([[1, 0, 0], [0, 0, 0]])
C = torch.tensor([[1, 0, 0], [1, 0, 1]])

Hi,

I think this can help:

import torch

A = torch.tensor([[1, 0, 0], [0, 1, 0]])
B = torch.tensor([[1, 0, 0], [0, 0, 0]])
C = torch.tensor([[1, 0, 0], [1, 0, 1]])

(torch.sum(C, dim=1) == 1).all()

torch.sum(tensor, dim) computes the sum along the specified dimension, so with dim=1 you get one sum per row. For instance, for C:

torch.sum(C, dim=1)  # output: tensor([1, 2])

This means the second row of C contains two 1s. Comparing the sums with == 1 turns this into a boolean tensor that is True only for rows summing to exactly one:

torch.sum(C, dim=1) == 1  # output: tensor([ True, False])

The comparison has to be == 1 rather than <= 1: the second row of B is all zeros, so its sum is 0, and <= 1 would wrongly accept B.

We need the condition to hold for every row, so we add .all() at the end, which reduces the boolean tensor to a single True or False (a 0-dim tensor; call .item() if you need a plain Python bool). For C it gives tensor(False).
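One caveat: the row-sum check assumes the tensor contains only 0s and 1s, since a row such as [2, -1, 0] also sums to 1. Below is a minimal sketch of a stricter check, assuming a 2-D tensor; the helper name is_valid_indicator is hypothetical and not part of PyTorch.

```python
import torch


def is_valid_indicator(t: torch.Tensor) -> bool:
    """Hypothetical helper: True if every row of a 2-D tensor has exactly
    one 1 and zeros everywhere else."""
    # Every entry must be 0 or 1 (rules out rows like [2, -1, 0] that also sum to 1).
    binary = ((t == 0) | (t == 1)).all()
    # Each row must contain exactly one entry equal to 1.
    one_per_row = ((t == 1).sum(dim=1) == 1).all()
    # Combine the two 0-dim boolean tensors and convert to a Python bool.
    return bool((binary & one_per_row).item())
```

With the tensors from the question, is_valid_indicator(A) returns True, while B and C both give False.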

Best,