Differentiable torch.histc?

What is Gaussian expansion? There may be something wrong with your link.

Please see the Methods section.

  1. Gaussian feature expansion of the inter-atomic distances
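For the record, "Gaussian feature expansion" of scalar distances against a grid of centers can be sketched like this (the grid range, spacing, and width below are placeholders, not values from the paper):

    import torch

    def gaussian_expansion(d, centers, width=0.1):
        # d: (N,) inter-atomic distances, centers: (K,) grid of reference distances
        # returns (N, K) features exp(-(d - center)**2 / (2 * width**2))
        return torch.exp(-((d.unsqueeze(-1) - centers.unsqueeze(0)) ** 2) / (2 * width ** 2))

    centers = torch.linspace(0.0, 5.0, 50)                            # hypothetical grid of centers
    features = gaussian_expansion(torch.tensor([1.2, 2.7]), centers)  # shape (2, 50)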

Hi @Yihui_Ren
Errors occur in my case with your forward function.
I think the following might be better to use. :wink:

    def forward(self, x):
        n, l = x.size()
        # (n, 1, l)
        x = torch.unsqueeze(x, 1)
        # (n, bins, 1): repeat the centers across the batch dimension
        # (use a local variable so self.centers keeps its original shape between calls)
        centers = self.centers.unsqueeze(0).repeat(n, 1).unsqueeze(2)
        # (n, bins, l)
        x = x - centers
        # normalized Gaussian kernel, scaled by the bin width delta
        x = torch.exp(-0.5 * (x / self.sigma) ** 2) / (self.sigma * np.sqrt(np.pi * 2)) * self.delta
        # (n, bins)
        x = x.sum(dim=2)
        return x
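For completeness, the forward above assumes the module holds centers, sigma, and delta set up along these lines (matching the SoftHistogram constructor further down the thread; np refers to numpy):

    def __init__(self, bins, min, max, sigma):
        super().__init__()
        self.sigma = sigma
        # bin width
        self.delta = float(max - min) / float(bins)
        # bin centers on a regular grid over [min, max]
        self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5)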

Hi @ProkaryMonster,

Could you please elaborate on why the code wasn’t working? It might help others to understand it better.

I think I made an implicit assumption that we do histc on the last dimension, hence dim=-1. Is this the cause of the trouble in your case?

Thank you.

Hi @Yihui_Ren
I found that the pixels of my result always end up with nearly identical values, and the gradients for one sample in the batch stay at 0.
I cannot figure out why it happens, but I noticed that the centers in your code might not be repeated across the batch dim, and after fixing that, my code works well.
Maybe the error only occurs in my case?

hi @Tony-Y, can you explain why you multiply by self.delta? In https://www.nature.com/articles/ncomms13890.pdf (Quantum-chemical insights from deep tensor neural networks), Eq. 5, there is no such multiplication. What is surprising is that this multiplication gives a better approximation of the true histogram computed using torch.histc. thanks.

hi @Tony-Y, can you explain in more detail how you came up with this difference, and what exactly it is computing?

and why are you shifting the difference from the centers to the left and to the right?

any references to these equations would be helpful. thanks

also, using a sigmoid will tend to give an almost-zero gradient (depending on the required accuracy of the histogram, which partly depends on the temperature sigma). thanks
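A quick illustration of that concern (the value and the sharpness below are arbitrary):

    import torch

    x = torch.tensor([2.0], requires_grad=True)   # a value far from the bin edge
    k = 50.0                                      # sharp sigmoid
    torch.sigmoid(k * x).sum().backward()
    print(x.grad)   # tensor([0.]): the saturated sigmoid passes essentially no gradient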

I think the interval (delta) is necessary for the normalization.
[figure: calculus illustration of the normalization]

I hope this graph helps your understanding.
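Spelled out: with bin centers c_b spaced delta apart, each sample x_i contributes

    delta * exp(-(x_i - c_b)**2 / (2 * sigma**2)) / (sigma * sqrt(2 * pi))

to bin b. Summed over all bins, this is a Riemann sum (with spacing delta) of the normalized Gaussian, so each sample contributes a total of roughly 1 spread over the bins, just as it would add exactly 1 to one bin of a hard histogram. Without the delta factor, each sample would contribute about 1/delta in total instead.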

I think you can simply use exp(-(x/sigma)**2), i.e. a Gaussian RBF; all other normalization terms can be replaced by normalizing the final output (to sum(output) = len(input), if one wishes to emulate histc).
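A minimal sketch of that post-hoc normalization (the shapes and values here are assumptions for illustration):

    import torch

    B, L, bins, sigma = 2, 1000, 256, 0.01
    x = torch.rand(B, L)
    centers = torch.linspace(0, 1, bins)

    # un-normalized Gaussian RBF responses, summed over the input: (B, bins)
    out = torch.exp(-(((x.unsqueeze(1) - centers.view(1, -1, 1)) / sigma) ** 2)).sum(dim=-1)
    # rescale each row so its bins sum to len(input), as torch.histc would
    out = out * (L / out.sum(dim=-1, keepdim=True))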

Your thought seems to have already been pointed out by @Yihui_Ren.

I rather wanted to point out that, with a big input, this wastes memory and time on unnecessary multiplications, and I believe at least one extra tensor is kept around for backprop.

hi @Tony-Y, thanks. were the 2 figures taken from some reference (paper/book)? can you provide the references? thanks

No references. I created the 2 figures for you.

hi @Tony-Y,

if we set:

  • L = f(x + delta/2)
  • R = f(x - delta/2)

->

  • from your code, it seems that you implemented L - R, which is ok because L and R are just two sigmoids, the first shifted to the left and the second to the right. this difference equals one only when x0 - delta/2 < x < x0 + delta/2, which is what we are looking for.
  • from the second figure, it seems to me that to add +1 to the bin, x needs to be in the blue window with value 1, i.e. the intersection of L and R.
    to check this condition one can simply do: L * (1 - R).
  • i computed (L * (1 - R) - (L - R)).abs().sum() and found 0, meaning they are the same (as they are supposed to be).
  • the sigmoid in the figure may need a (-) to be 1/(1 + exp(-kx)).

i didn’t get the relation between your code and your second figure right away.
thanks again

That’s right. Sorry for this typo.

If you know the Heaviside step function, H(x), the rectangular function is H(x + 0.5) - H(x - 0.5). H(x) can be approximated by the logistic function. The second figure illustrates this.

A histogram bin's value can be computed as the sum of the rectangular function's values over all samples.
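Written out as code (k is the sharpness of the logistic approximation, my naming):

    import torch

    def soft_bin_count(x, center, delta, k=50.0):
        # rect((x - center) / delta), approximated by a difference of two logistic functions
        z = (x - center) / delta
        rect = torch.sigmoid(k * (z + 0.5)) - torch.sigmoid(k * (z - 0.5))
        # the bin's value is the sum of the rectangular function over all samples
        return rect.sum()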

thank you again for the help.

could you use a “straight through” sigmoid?

how about this one?

import torch
import torch.nn as nn

class SoftHistogram(nn.Module):
    def __init__(self, bins, min, max, sigma):
        super(SoftHistogram, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.delta = float(max - min) / float(bins)
        # bin centers, kept as a buffer so they follow the module to the input's device
        self.register_buffer("centers", float(min) + self.delta * (torch.arange(bins).float() + 0.5))

    ## previously by Tony-Y
    def forward(self, x):
        ## flatten the input into [B, len]
        x = x.view(x.shape[0], -1)
        x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, -1)
        # difference of two shifted sigmoids: a soft rectangular window of width delta around each center
        x = torch.sigmoid(self.sigma * (x + self.delta / 2)) - torch.sigmoid(self.sigma * (x - self.delta / 2))
        x = x.sum(dim=-1)
        return x

    ### now
    def forward_cmp(self, input):
        ## flatten the [B, C, H, W] input into [B, len]
        b, c, h, w = input.shape
        input = input.view(b, -1)
        x = torch.unsqueeze(input, 1) - torch.unsqueeze(self.centers, -1)
        x = torch.sigmoid(self.sigma * x)   # ~1 where the value exceeds a center, ~0 otherwise
        # adjacent differences, padded with 1 on the left and 0 on the right, give soft per-bin membership
        diff = (torch.cat([torch.ones((b, 1, h * w), device=input.device), x], dim=1)
                - torch.cat([x, torch.zeros((b, 1, h * w), device=input.device)], dim=1))
        diff = diff.sum(dim=-1)
        # fold the overflow bin (values above the last center) into the last bin
        diff[:, -2] += diff[:, -1]
        return diff[:, :-1]

With

data = (256*torch.rand((B,1,64,64))).round_()
hist_nondiff = torch.histc(data, bins=256, min=0, max=256)

as an example test, the new function can reach an error of 0.
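For reference, one way to run this check end to end (B and sigma are my choices, not values from the post; the per-sample soft histograms are summed over the batch because torch.histc flattens everything):

    B = 4
    softhist = SoftHistogram(bins=256, min=0, max=256, sigma=100.0)

    data = (256 * torch.rand((B, 1, 64, 64))).round_()
    hist_nondiff = torch.histc(data, bins=256, min=0, max=256)
    hist_diff = softhist.forward_cmp(data).sum(dim=0)

    print((hist_diff - hist_nondiff).abs().max())   # ~0 with a sufficiently sharp sigma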