A differentiable "soft" histogram can be used in place of `torch.histc`, which does not propagate gradients back to its input.
import torch
import torch.nn as nn
import numpy as np
# 1000 samples drawn from a normal distribution with mean 50 and std 25.
data = torch.randn(1000) * 25 + 50
# Hard (non-differentiable) histogram: 10 equal-width bins over [0, 100].
# Samples outside [0, 100] are ignored by histc.
hist = data.histc(bins=10, min=0, max=100)
print(hist)
class SoftHistogram(nn.Module):
    """Differentiable approximation of ``torch.histc``.

    Each bin is a smooth box built from the difference of two sigmoids, so
    gradients flow back to the input values. As ``sigma`` grows, the box edges
    sharpen and the result approaches ``torch.histc(x, bins, min, max)``.

    Args:
        bins: Number of equal-width bins covering ``[min, max]``.
        min: Lower edge of the first bin. (Shadows the builtin; the name is
            kept for API compatibility.)
        max: Upper edge of the last bin. (Same note as ``min``.)
        sigma: Edge sharpness. It multiplies the distance (in data units)
            between a value and a bin edge before the sigmoid, so larger
            values give harder edges.

    Raises:
        ValueError: If ``bins`` is not positive or ``max`` is not greater
            than ``min``.
    """

    def __init__(self, bins, min, max, sigma):
        super().__init__()
        # Without these checks, bins == 0 raised ZeroDivisionError and
        # max <= min silently produced all-zero or negative "counts".
        if bins <= 0:
            raise ValueError(f"bins must be positive, got {bins}")
        if not max > min:
            raise ValueError(f"max must be greater than min, got min={min}, max={max}")
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.delta = float(max - min) / float(bins)
        # Bin centers: bin k spans [min + k*delta, min + (k+1)*delta].
        # Kept as a plain attribute (not a buffer) so state_dict is unchanged;
        # forward() moves it to the input's device instead.
        self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5)

    def forward(self, x):
        """Return the soft count for each bin.

        Per input value v, the bin weights telescope to
        ``sigmoid(sigma*(v-min)) - sigmoid(sigma*(v-max))``. Each value
        therefore contributes at most 1 in total. Values well outside
        ``[min, max]`` contribute about 0, and values near the range edges
        contribute partially.

        Args:
            x: Tensor of values. Like ``torch.histc``, it is flattened, so any
                shape is accepted.

        Returns:
            Tensor of shape ``(bins,)`` holding the soft counts.
        """
        # Flatten like torch.histc; a 1-D input is unaffected.
        x = x.reshape(-1)
        # module.to()/.cuda() does not move plain tensor attributes, so align
        # the centers with the input's device here.
        centers = self.centers.to(x.device)
        # (bins, N) signed distances from every value to every bin center.
        diff = torch.unsqueeze(x, 0) - torch.unsqueeze(centers, 1)
        half = self.delta / 2
        # Smooth indicator of |diff| < delta/2: rises at the bin's lower edge,
        # falls at its upper edge.
        weights = torch.sigmoid(self.sigma * (diff + half)) - torch.sigmoid(self.sigma * (diff - half))
        return weights.sum(dim=1)
# Soft histogram over [0, 100] with 10 bins and edge sharpness 3.
softhist = SoftHistogram(min=0, max=100, bins=10, sigma=3)
# Track gradients on the samples so the soft counts can be backpropagated.
data.requires_grad_(True)
hist = softhist(data)
print(hist)
# The total soft mass is a scalar, so backward() needs no gradient argument.
total_mass = hist.sum()
total_mass.backward()
# Largest per-sample gradient of the total mass.
print(data.grad.max())
Output (exact values vary between runs because `torch.randn` is not seeded):
tensor([ 27., 59., 89., 134., 153., 183., 142., 97., 52., 32.])
tensor([ 26.2054, 60.0647, 87.4429, 135.9337, 150.7281, 184.4028, 139.8741,
97.9176, 52.9988, 30.5099], grad_fn=<SumBackward2>)
tensor(0.7499)