@AreTor There are many possible combinations of values, e.g.:
a=[0.44, 0.55, 0.01]
What should be done in this case, where neither 0.44 nor 0.55 is close to 0 or 1?
A simple approach, which leaves rows like the one above unchanged, is shown below.
I am assuming the last axis is the one whose values sum to 1.
def snap_near_one_hot(a, threshold=0.05):
    """Snap rows (along the last dim) that are already almost one-hot to exact 0/1.

    An element counts as "near 0" when ``|x| < threshold`` and as "near 1"
    when ``|x| > 1 - threshold``. A row along the last dimension is snapped
    only when *every* element in it is near 0 or near 1. Any other row (e.g.
    ``[0.44, 0.55, 0.01]``) is returned unchanged.

    Args:
        a: Tensor of any shape; the last dimension is treated as the row.
        threshold: Tolerance used for both the "near 0" and "near 1" tests.

    Returns:
        A new tensor with the same shape as ``a``. ``a`` itself is not modified.
    """
    magnitude = a.abs()
    near_one = magnitude > 1 - threshold
    # near_one takes precedence when both tests hold (only possible when
    # threshold > 0.5), matching the original marker-tensor approach where the
    # "near 1" marker was written last.
    near_zero = (magnitude < threshold) & ~near_one
    # keepdim=True lets the per-row mask broadcast back over the last dim for
    # any input shape, instead of a hard-coded view(...).repeat(...).
    # NaN entries (e.g. from a zero-sum row) fail both tests, so such rows stay untouched.
    decided = (near_zero | near_one).all(dim=-1, keepdim=True)
    snapped = a.clone()
    snapped[decided & near_one] = 1
    snapped[decided & near_zero] = 0
    return snapped


a = torch.rand(2, 3, 2, 3)
# Force two rows that are almost one-hot so the snapping can be checked below
a[0, 0, 0] = torch.tensor([0.99, 0.005, 0.005])
a[0, 2, 1] = torch.tensor([0.99, 0.005, 0.005])
# Normalise so every row along the last dim sums to 1
a = a / a.sum(dim=-1, keepdim=True)
# Unit test: each row along the last axis now sums to 1
assert torch.allclose(a.sum(dim=-1), torch.ones(a.shape[:-1]))
# Setting the threshold val
threshold = 0.05
print("Before a was")
print(a)
a = snap_near_one_hot(a, threshold)
print("After a is")
print(a)
# Unit test: both forced rows were snapped to exact one-hot vectors
expected_row = torch.tensor([1.0, 0.0, 0.0])
assert torch.equal(a[0, 0, 0], expected_row)
assert torch.equal(a[0, 2, 1], expected_row)