I would appreciate help checking whether a tensor is a valid indicator function,
meaning that in each row every element is zero except exactly one.
So A is valid, while B and C are not:
A = torch.tensor([[1, 0, 0], [0, 1, 0]])
B = torch.tensor([[1, 0, 0], [0, 0, 0]])
C = torch.tensor([[1, 0, 0], [1, 0, 1]])
However, the condition has to hold for every row, so every value in the per-row result must be True. To check this, we add .all() at the end, which reduces the per-row result to a single boolean: True if every row satisfies the condition and False otherwise.
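Here is a minimal sketch of one way to put this together. The helper name is_valid_indicator is hypothetical, and the check assumes that the single non-zero entry in each row must be 1 and that the tensor is 2-D; drop the first condition if any non-zero value should count.

```python
import torch

def is_valid_indicator(t: torch.Tensor) -> bool:
    # Assumption: entries may only be 0 or 1.
    binary = ((t == 0) | (t == 1)).all()
    # Per-row condition: exactly one non-zero entry in each row.
    # .all() then requires the condition to hold for every row.
    one_per_row = ((t != 0).sum(dim=1) == 1).all()
    return bool(binary & one_per_row)

A = torch.tensor([[1, 0, 0], [0, 1, 0]])
B = torch.tensor([[1, 0, 0], [0, 0, 0]])
C = torch.tensor([[1, 0, 0], [1, 0, 1]])

print(is_valid_indicator(A))  # True
print(is_valid_indicator(B))  # False (second row has no non-zero entry)
print(is_valid_indicator(C))  # False (second row has two non-zero entries)
```

Without the final .all(), the comparison returns one boolean per row, which is useful if you want to see which rows fail.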