Differentiable torch.histc?

Hi everyone, I found out that the derivative for ‘histc’ is not implemented. Is there a way to easily implement a histogram function with derivatives?

Thank you for your time


Here’s something that might fill the need:

import torch
import torch.nn as nn
import numpy as np

class HistoBin(nn.Module):
    def __init__(self, locations=np.arange(0, 1, .1), radius=.2, norm=True):
        super(HistoBin, self).__init__()
        self.locs = locations
        self.r = radius
        self.norm = norm

    def forward(self, x):
        # x: (batch, samples); each value contributes to every bin center within `radius`
        counts = []
        for loc in self.locs:
            dist = torch.abs(x - loc)
            # triangular kernel: weight decreases linearly with distance from the bin center
            ct = torch.relu(self.r - dist).sum(1)
            counts.append(ct)
        out = torch.stack(counts, 1)
        if self.norm:
            summ = out.sum(1) + .000001
            return (out.transpose(1, 0) / summ).transpose(1, 0)
        return out
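
For reference, a minimal usage sketch (my own example, not from the original post): HistoBin as written expects a 2-D input of shape (batch, samples), with values roughly covered by locations.

# Hypothetical usage of the HistoBin module defined above
import torch

x = torch.rand(4, 1000, requires_grad=True)   # 4 samples of 1000 values in [0, 1)
histo = HistoBin()                             # default bin locations 0.0, 0.1, ..., 0.9
h = histo(x)                                   # shape (4, 10), one soft histogram per row
h.sum().backward()                             # gradients flow back to x
print(h.shape, x.grad.shape)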

A Gaussian expansion could be used as an approximate histogram:

import torch
import torch.nn as nn
import numpy as np

data = 50 + 25 * torch.randn(1000)

hist = torch.histc(data, bins=10, min=0, max=100)

print(hist)

class GaussianHistogram(nn.Module):
    def __init__(self, bins, min, max, sigma):
        super(GaussianHistogram, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.delta = float(max - min) / float(bins)
        self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5)

    def forward(self, x):
        # (bins, N) matrix of signed distances between each sample and each bin center
        x = torch.unsqueeze(x, 0) - torch.unsqueeze(self.centers, 1)
        # Gaussian kernel of width sigma, scaled by the bin width delta
        x = torch.exp(-0.5*(x/self.sigma)**2) / (self.sigma * np.sqrt(np.pi*2)) * self.delta
        # sum over samples -> soft count per bin
        x = x.sum(dim=1)
        return x

gausshist = GaussianHistogram(bins=10, min=0, max=100, sigma=6)

data.requires_grad = True
hist = gausshist(data)
print(hist)

hist.sum().backward()
print(data.grad)

Output:

tensor([ 41.,  63.,  98., 123., 145., 158., 128., 106.,  67.,  29.])
tensor([ 38.4469,  65.1605,  92.4981, 122.9916, 145.3853, 150.2217, 130.5083,
        102.5501,  69.8868,  31.8439], grad_fn=<SumBackward2>)
tensor([ 6.4244e-02, -7.6090e-02,  5.6231e-05,  9.3874e-04, -7.5421e-04,
        -2.4916e-04, -1.0152e-03,  6.8575e-04,  1.8976e-04,  6.6737e-04,
        -7.3137e-05,  4.9042e-02,  1.8404e-04, -2.6847e-04, -5.8353e-04,

The soft histogram might be usable instead of torch.histc.

import torch
import torch.nn as nn
import numpy as np

data = 50 + 25 * torch.randn(1000)

hist = torch.histc(data, bins=10, min=0, max=100)

print(hist)

class SoftHistogram(nn.Module):
    def __init__(self, bins, min, max, sigma):
        super(SoftHistogram, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.delta = float(max - min) / float(bins)
        self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5)

    def forward(self, x):
        # (bins, N) matrix of signed distances between each sample and each bin center
        x = torch.unsqueeze(x, 0) - torch.unsqueeze(self.centers, 1)
        # difference of two sigmoids ~ soft indicator for "x lies inside this bin"
        x = torch.sigmoid(self.sigma * (x + self.delta/2)) - torch.sigmoid(self.sigma * (x - self.delta/2))
        # sum over samples -> soft count per bin
        x = x.sum(dim=1)
        return x

softhist = SoftHistogram(bins=10, min=0, max=100, sigma=3)

data.requires_grad = True
hist = softhist(data)
print(hist)

hist.sum().backward()
print(data.grad.max())

Output:

tensor([ 27.,  59.,  89., 134., 153., 183., 142.,  97.,  52.,  32.])
tensor([ 26.2054,  60.0647,  87.4429, 135.9337, 150.7281, 184.4028, 139.8741,
         97.9176,  52.9988,  30.5099], grad_fn=<SumBackward2>)
tensor(0.7499)

This function only seems to work for your input. For example, input values between -1 and 1 don’t work.

In fact, it only works for the specific data and parameters provided.

Please adjust the parameters to your data.

Well, there is only one parameter, sigma. I’ve tried numerous values; none work. I welcome you to try: simply replace “data = 50 + 25 * torch.randn(1000)” with “data = torch.randn(1000)”. I’m guessing neither gausshist nor softhist generalises particularly well, and both require the input data to have a distribution resembling a Gaussian one.

Please try min=-2, max=2, sigma=3/25. I shifted and scaled the parameters to match the standard normal data.

Thanks for the reply. Unfortunately, that doesn’t work; it was the first thing I tried. It works even less well when you increase the number of bins.

I am sorry. The sigma should be 3*25.
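
In other words, something like this (a sketch of mine, reusing the SoftHistogram class defined above with the corrected parameters):

data = torch.randn(1000)                                  # standard normal instead of 50 + 25 * torch.randn(1000)
print(torch.histc(data, bins=10, min=-2, max=2))          # hard histogram for comparison

softhist = SoftHistogram(bins=10, min=-2, max=2, sigma=3*25)
data.requires_grad = True
print(softhist(data))                                     # soft, differentiable approximation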

When min=-2, max=2, sigma=3*25, I got the following:

tensor([ 36.,  72., 105., 136., 134., 148., 132.,  92.,  55.,  37.])
tensor([ 36.3019,  71.9146, 104.1209, 138.0593, 133.3760, 144.7053, 134.9035,
         92.9224,  52.8923,  36.8525], grad_fn=<SumBackward2>)
tensor(18.4180)

The parameter name “sigma” is a bit confusing. I think “kappa” or “k” would be better.

Yep, it’s getting there. Now set bins=255.

I get tensor([ 0., 0., 0., 2., 0., 1., 1., 1., 2., 0., 1., 0., 0., 0.,
1., 1., 1., 1., 1., 2., 3., 0., 2., 1., 1., 2., 0., 2.,
1., 2., 3., 0., 0., 2., 1., 0., 3., 4., 2., 3., 1., 2.,
4., 2., 1., 1., 4., 2., 5., 1., 4., 3., 3., 4., 0., 2.,
5., 6., 4., 3., 8., 1., 4., 5., 3., 0., 3., 5., 4., 4.,
2., 6., 3., 7., 2., 10., 2., 5., 10., 3., 7., 6., 6., 7.,
4., 3., 2., 7., 4., 7., 3., 4., 5., 5., 5., 2., 3., 6.,
3., 9., 6., 7., 7., 6., 5., 7., 5., 6., 6., 5., 7., 3.,
5., 4., 2., 7., 2., 10., 4., 5., 11., 11., 5., 9., 6., 7.,
4., 7., 6., 9., 7., 11., 6., 7., 7., 5., 6., 5., 6., 4.,
9., 8., 7., 12., 9., 8., 1., 6., 4., 5., 5., 6., 5., 7.,
11., 11., 4., 7., 9., 3., 4., 6., 4., 3., 4., 10., 3., 2.,
6., 11., 4., 4., 6., 7., 4., 5., 5., 3., 3., 3., 5., 4.,
8., 2., 4., 4., 2., 9., 6., 4., 4., 2., 3., 6., 3., 2.,
2., 5., 3., 0., 2., 1., 6., 2., 3., 0., 6., 4., 0., 4.,
1., 0., 2., 3., 2., 5., 1., 2., 5., 2., 2., 1., 3., 4.,
1., 5., 2., 2., 1., 4., 3., 2., 3., 3., 1., 0., 4., 0.,
1., 0., 0., 2., 2., 0., 1., 0., 0., 2., 1., 1., 0., 2.,
0., 2., 0.])
tensor([0.2294, 0.2264, 0.4174, 0.6920, 0.8129, 0.8522, 0.9543, 1.0392, 1.0053,
0.8288, 0.6072, 0.4167, 0.3240, 0.4156, 0.6070, 0.8465, 1.0836, 1.1995,
1.3867, 1.6308, 1.6203, 1.4406, 1.3543, 1.2456, 1.1760, 1.1920, 1.1754,
1.2341, 1.4459, 1.7166, 1.6359, 1.1794, 0.9422, 1.0185, 1.1744, 1.5338,
2.1595, 2.6610, 2.7007, 2.4503, 2.2206, 2.2761, 2.3876, 2.1789, 1.9693,
2.1244, 2.4716, 2.7490, 2.9155, 2.9682, 3.0545, 3.1094, 3.0652, 2.7475,
2.4773, 2.9287, 3.8287, 4.4061, 4.3883, 4.2728, 4.3133, 4.1353, 3.9032,
3.5445, 3.0214, 2.7678, 3.0349, 3.4247, 3.5377, 3.5801, 3.9689, 4.5266,
4.7314, 4.8329, 5.1060, 5.3357, 5.2748, 5.6173, 6.1142, 6.1377, 6.0390,
6.0027, 5.9284, 5.4285, 4.6102, 4.0881, 4.1167, 4.4975, 4.7918, 4.8103,
4.5924, 4.5281, 4.7080, 4.6561, 4.1667, 3.7260, 3.9051, 4.5131, 5.3062,
6.2507, 6.7920, 6.8175, 6.6299, 6.2593, 5.9633, 5.8159, 5.6079, 5.4935,
5.5841, 5.6613, 5.3981, 4.9287, 4.5209, 4.0846, 3.9914, 4.4455, 5.1549,
5.9601, 6.4060, 7.0180, 8.1164, 8.5827, 8.1205, 7.4650, 6.5912, 5.9549,
5.8980, 6.1648, 6.8313, 7.6569, 8.1927, 8.1415, 7.4805, 6.9542, 6.6359,
6.1840, 5.7488, 5.6370, 5.8313, 6.2864, 7.1439, 7.9608, 8.6149, 8.8756,
8.0024, 6.5094, 5.2868, 4.8778, 4.9382, 5.0370, 5.1563, 5.5232, 6.3218,
7.4457, 8.2003, 8.1334, 7.5174, 7.0089, 6.4302, 5.5020, 4.8912, 4.7707,
4.6245, 4.5781, 4.8791, 5.0248, 4.6596, 4.7419, 5.6951, 6.3835, 6.0732,
5.7131, 5.5538, 5.1487, 4.7616, 4.6408, 4.3855, 3.9211, 3.7476, 3.9929,
4.3637, 4.6391, 4.6718, 4.3453, 4.0791, 4.0660, 4.5383, 5.3572, 5.4346,
4.6449, 3.8484, 3.4765, 3.6354, 3.7055, 3.2702, 2.8501, 2.7852, 2.8737,
2.6213, 2.1645, 2.0984, 2.5559, 2.9686, 2.8456, 2.6406, 2.8016, 3.1273,
2.9296, 2.4468, 2.0567, 1.5683, 1.3667, 1.7574, 2.3771, 2.7778, 2.8132,
2.7282, 2.8660, 2.8734, 2.4603, 2.0452, 2.0514, 2.4907, 2.9116, 3.0435,
2.9171, 2.5364, 2.2664, 2.3686, 2.6276, 2.6501, 2.5337, 2.4915, 2.2262,
1.7326, 1.5266, 1.5451, 1.2724, 0.8348, 0.6324, 0.7961, 1.0441, 1.0182,
0.7880, 0.6156, 0.5596, 0.7096, 0.9891, 1.0293, 0.8919, 0.8464, 0.9157,
0.9453, 0.8344, 0.6318])

There may be errors of about 1 because of the softening. Keep in mind that the Heaviside step functions in the histogram are replaced with sigmoid functions.
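
To make the softening explicit, here is a small illustration of my own (not from the original posts) comparing the hard bin indicator with its sigmoid replacement for a single bin:

import torch

delta = 10.0                         # bin width, as in the 10-bin examples above
sigma = 3.0                          # softening parameter
x = torch.linspace(-15, 15, 7)       # signed distance from the bin center

# hard indicator: 1 if the point falls inside [-delta/2, delta/2), else 0
hard = ((x >= -delta/2) & (x < delta/2)).float()

# soft indicator: difference of two sigmoids, approaching the hard one as sigma grows
soft = torch.sigmoid(sigma * (x + delta/2)) - torch.sigmoid(sigma * (x - delta/2))

print(hard)
print(soft)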

Yeah, I don’t think this is a solvable problem. You can sharpen the sigmoid function, but in doing so you’re going to increase your gradients. As your “sharpened” sigmoid approaches the Heaviside function, your gradients are going to diverge. I think histograms are a lost cause unless you use them as the very first input of your model, where tensors don’t require gradients.
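
A quick way to see that growth (my own sketch, reusing the SoftHistogram class from above; the sigma values are arbitrary): the largest gradients scale roughly linearly with sigma.

data = 50 + 25 * torch.randn(1000)
for sigma in (3, 30, 300):
    x = data.clone().requires_grad_(True)
    SoftHistogram(bins=10, min=0, max=100, sigma=sigma)(x).sum().backward()
    print(sigma, x.grad.abs().max().item())   # max gradient grows as the sigmoid sharpens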

You are right. Actually, I didn’t use histograms in any neural network, and don’t want to use them. I just created the soft histogram in analogy with the softmax, focusing on the softening.

It’s a shame, though. Histograms are such useful features, ideal to plug into your final dense layer. Oh well.

I hope your attempt succeeds.
Good luck!

Hi @Tony-Y

Thank you for your code snippet. It is very helpful.
I have one suggestion:
In order to make model.cuda() work, we need to wrap self.centers in nn.Parameter(), like:

self.centers = nn.Parameter(self.centers, requires_grad=False)
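
With that change applied to GaussianHistogram, something like the following should work (a sketch; it assumes a CUDA device is available):

# Sketch: after wrapping self.centers in nn.Parameter, .cuda() moves it along with the module
gausshist = GaussianHistogram(bins=10, min=0, max=100, sigma=6).cuda()
data = (50 + 25 * torch.randn(1000, device='cuda')).requires_grad_(True)
hist = gausshist(data)
hist.sum().backward()
print(hist, data.grad.abs().max())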

Also, if we expect the first dimension to be the batch size, I would update the forward() function as follows:


    def forward(self, x):
        # x: (batch, samples); broadcast against the bin centers -> (batch, bins, samples)
        x = torch.unsqueeze(x, 1) - torch.unsqueeze(self.centers, 1)
        x = torch.exp(-0.5*(x/self.sigma)**2) / (self.sigma * np.sqrt(np.pi*2)) * self.delta
        x = x.sum(dim=-1)
        x = x/x.sum(dim=-1).unsqueeze(1) # normalization per sample
        return x
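
A quick usage sketch for this batched variant (my own example; it assumes both the nn.Parameter wrap and the updated forward() above):

gausshist = GaussianHistogram(bins=10, min=0, max=100, sigma=6)
batch = (50 + 25 * torch.randn(8, 1000)).requires_grad_(True)
hist = gausshist(batch)          # shape (8, 10); each row sums to 1 after normalization
hist.sum().backward()
print(hist.shape, batch.grad.shape)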