"Peaking" a distribution tensor

I have a tensor of probability distributions t of size (n, c, h, w), such that t.sum(dim=1) == 1.

What’s the best way to retrieve the low-entropy predictions and to “peak” them so that they become one-hot, while preserving differentiability? The retrieval could also be done with a threshold on the probabilities (e.g. selecting the distributions p for which some p[i] > 0.9).

For example, if I have [0.05, 0.9, 0.05] the peaked output would be [0., 1., 0.].

How can I efficiently perform the peaking?
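One way to do the retrieval by entropy rather than by a threshold is sketched below. It assumes the class dimension is `dim=1`, as in the question; the helper name `peak_low_entropy` and the cutoff `max_entropy` are hypothetical choices for illustration. Note that the one-hot built from `argmax` carries no gradient by itself, so only the untouched distributions remain differentiable here.

```python
import torch
import torch.nn.functional as F


def peak_low_entropy(t: torch.Tensor, max_entropy: float = 0.5, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical helper: replace low-entropy distributions along dim=1 with one-hot vectors."""
    # Shannon entropy (in nats) of each distribution, shape (n, 1, h, w)
    entropy = -(t * torch.log(t.clamp_min(eps))).sum(dim=1, keepdim=True)
    low = entropy < max_entropy
    # One-hot of the argmax: (n, h, w, c), moved back to the channel layout (n, c, h, w)
    one_hot = F.one_hot(t.argmax(dim=1), num_classes=t.size(1))
    one_hot = one_hot.movedim(-1, 1).to(t.dtype)
    # Peak only the low-entropy positions and leave the others untouched
    return torch.where(low, one_hot, t)
```

With `c` classes the entropy lies between 0 and log(c), so the cutoff has to be chosen relative to that range: [0.05, 0.9, 0.05] has an entropy of about 0.39 nats, while [0.44, 0.55, 0.01] has about 0.74.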


A somewhat naive way would be to subtract the threshold from t, pass the result through a plain ReLU (intercept 0, slope 1), and finally divide by its sum over the class dimension so that each distribution sums to 1 again.
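A minimal sketch of this ReLU idea, assuming the class dimension is `dim=1`; the function name `relu_peak` is hypothetical. A distribution with no entry above the threshold would give a zero sum, so this version leaves such distributions unchanged instead of dividing by zero. With a threshold of at least 0.5, at most one entry can survive the ReLU, so every peaked distribution comes out exactly one-hot.

```python
import torch
import torch.nn.functional as F


def relu_peak(t: torch.Tensor, threshold: float = 0.5, eps: float = 1e-12) -> torch.Tensor:
    """Hypothetical helper: peak distributions along dim=1 via ReLU(t - threshold)."""
    r = F.relu(t - threshold)          # keep only the mass above the threshold
    s = r.sum(dim=1, keepdim=True)     # zero where no entry exceeds the threshold
    peaked = r / s.clamp_min(eps)      # renormalize; clamp avoids division by zero
    # Distributions with nothing above the threshold are returned unchanged
    return torch.where(s > 0, peaked, t)
```

The surviving entry becomes r / r = 1, so its gradient is zero; gradients only flow through the distributions that are left unchanged.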


@AreTor There are many possible combinations to consider, for example:

    a = [0.44, 0.55, 0.01]

What should be done in this case?

A simplistic approach, which leaves ambiguous cases like the one above untouched, is the following.

I am assuming the last axis is the one that sums to 1.

    import torch

    threshold, inv_threshold = 0.1, 0.9  # below threshold -> 0, above inv_threshold -> 1

    a = torch.rand(2, 3, 2, 3)
    a = a / a.sum(dim=3, keepdim=True)  # normalize so that the last axis sums to 1

    # Forcing two cases so that we can unit test: one peakable, one ambiguous
    a[0, 0, 0] = torch.tensor([0.02, 0.96, 0.02])
    a[0, 0, 1] = torch.tensor([0.44, 0.55, 0.01])
    # Unit test: last axis sums to 1
    assert torch.allclose(a.sum(dim=3), torch.ones(2, 3, 2))

    # Labels: 1 = undecided, 2 = needs to become 0, 3 = needs to become 1
    a1 = torch.ones(a.size())
    a1[a < threshold] = 2
    a1[a > inv_threshold] = 3

    print("Before a was", a)
    # Only peak the distributions in which no entry is left undecided
    decided = (a1 == 1).sum(dim=3, keepdim=True) == 0
    a[decided & (a1 == 3)] = 1
    a[decided & (a1 == 2)] = 0
    print("After a is", a)

@GregorySech This is worth trying, as it would be an elegant way to use activation functions for approximating and rounding off values.

In the end, I solved it with the following piece of code:

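    import torch

    # x: (n, c, h, w) probability tensor -- t: threshold
    mask = x > t  # find values over threshold
    # Peak these values to 1 out of place (like x[mask] / x[mask], they get zero gradient)
    x = torch.where(mask, torch.ones_like(x), x)
    # Zero the other entries of peaked distributions; keep non-peaked ones as they are
    x = x * (mask | ~mask.any(dim=1, keepdim=True))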

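This works for thresholds greater than 0.5: then at most one entry per distribution can exceed the threshold, so every peaked distribution is exactly one-hot.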
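If the goal is a hard one-hot in the forward pass with non-zero gradients in the backward pass, a common alternative (not the method used above) is the straight-through estimator: compute the hard result, then add `x - x.detach()` so that backpropagation treats the peaking as the identity. A minimal sketch, assuming the same `(n, c, h, w)` layout and a threshold above 0.5; `peak_straight_through` is a hypothetical name.

```python
import torch


def peak_straight_through(x: torch.Tensor, t: float = 0.9) -> torch.Tensor:
    """Hypothetical helper: one-hot forward for peaked distributions, identity gradient backward."""
    mask = x > t                            # at most one entry per distribution if t > 0.5
    peaked = mask.any(dim=1, keepdim=True)  # which distributions get peaked, (n, 1, h, w)
    # Hard result: one-hot where peaked, the original probabilities elsewhere
    hard = torch.where(peaked, mask.to(x.dtype), x)
    # (x - x.detach()) is exactly 0 in the forward pass but has gradient 1 w.r.t. x
    return hard.detach() + (x - x.detach())
```

The trade-off is that the gradient is biased: it is the gradient of the soft probabilities, not of the one-hot output.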