Custom Entropy Loss function

We are working on a project about quantisation in a neural network layer and need to calculate the entropy of the quantised values. An example of the data that the entropy should get as input is: x = [1/8,1/8,4/8,2/8]. Currently we have the following entropy function:

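import torch

def Ixxhat(bitdepth, xhat):
    # Normalised histogram of xhat over [0, 1] with 2**bitdepth bins
    hist = torch.histc(xhat, bins=2**bitdepth, min=0, max=1) / xhat.numel()
    # p * log2(p) per bin, skipping empty bins (0 * log 0 is taken as 0).
    # Note: torch.tensor(...) builds a new tensor that is not connected to the autograd graph.
    xhatProbs = torch.tensor([hist[i] * torch.log2(hist[i]) if hist[i] > 0 else 0 for i in range(len(hist))])
    return -torch.sum(xhatProbs), hist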

We read that autograd should work as long as we only use torch functions. Can someone verify whether that holds for this function, or suggest an entropy function we could use instead? 🙂

Thanks in advance
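For reference, when the input is already a probability vector such as the example x = [1/8, 1/8, 4/8, 2/8], the entropy H = -Σ p log2 p can be written with elementwise torch operations only, so autograd can differentiate it with respect to those probabilities (not with respect to raw values that would first have to be histogrammed). A minimal sketch; the name `entropy_bits` is hypothetical and not from the thread, and `torch.xlogy` is used so that zero entries contribute 0:

```python
import math
import torch

def entropy_bits(p: torch.Tensor) -> torch.Tensor:
    """Shannon entropy in bits of a probability vector p (assumed to sum to 1)."""
    # torch.xlogy(p, p) computes p * ln(p) and returns 0 where p == 0
    return -torch.xlogy(p, p).sum() / math.log(2.0)

x = torch.tensor([1/8, 1/8, 4/8, 2/8], requires_grad=True)
h = entropy_bits(x)
print(round(h.item(), 4))  # 1.75
h.backward()               # works: every op above has a derivative
print(x.grad)
```

By hand: 2·(1/8)·3 + (1/2)·1 + (1/4)·2 = 0.75 + 0.5 + 0.5 = 1.75 bits.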

Hi Frederik!

No, this won’t work.

Being discrete in nature, the histogramming operation is not (usefully)
differentiable, so, mathematically, there is no way to backpropagate
through it.
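To make "not (usefully) differentiable" concrete, here is a small sketch with made-up values, chosen to stay away from the bin edges. The bin counts are piecewise constant in the inputs: small changes leave them untouched (derivative zero), and crossing an edge makes them jump by a whole sample (no derivative at all).

```python
import torch

# Values deliberately kept away from the bin edges 0.25, 0.5, 0.75
x = torch.tensor([0.05, 0.30, 0.31, 0.60, 0.90])
hist = torch.histc(x, bins=4, min=0, max=1)

# Nudging every value slightly leaves the counts unchanged, so a
# finite-difference estimate of the derivative is exactly zero
hist_nudged = torch.histc(x + 1e-3, bins=4, min=0, max=1)
print(hist, hist_nudged)            # both tensor([1., 2., 1., 1.])
print((hist_nudged - hist) / 1e-3)  # all zeros

# Crossing a bin edge instead makes the counts jump by a whole sample
left = torch.histc(torch.tensor([0.24]), bins=4, min=0, max=1)
right = torch.histc(torch.tensor([0.26]), bins=4, min=0, max=1)
print(left, right)  # tensor([1., 0., 0., 0.]) tensor([0., 1., 0., 0.])
```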

It is true – for reasons that I do not understand – that torch.histc()
does have a companion HistcBackward function. It doesn’t, however,
implement a derivative for histc() (because it wouldn’t make sense):

>>> import torch
>>> torch.__version__
>>> hist = torch.histc (torch.randn (10, requires_grad = True), bins = 10)
>>> hist.grad_fn
<HistcBackward object at 0x000001F87CB2E5F8>
>>> hist.sum().backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<path_to_pytorch_install>\torch\", line 255, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "<path_to_pytorch_install>\torch\autograd\", line 149, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: the derivative for 'histc' is not implemented.


K. Frank

Hi KFrank
Thank you for the answer. Do you know how we could implement the entropy in PyTorch?

Frederik 🙂
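One possible workaround (a swapped-in technique, not something proposed in this thread) is to replace the hard histogram with a "soft" one, where each value spreads Gaussian-weighted fractional counts over nearby bins. The result only approximates the histogram entropy, and the kernel width `sigma` is an assumed hyperparameter: smaller values track torch.histc more closely but give weaker, more local gradients. A minimal sketch, with hypothetical names `soft_histogram` and `soft_entropy`, assuming the values lie in [0, 1] as in the question's code:

```python
import torch

def soft_histogram(x, bins, vmin=0.0, vmax=1.0, sigma=None):
    """Differentiable stand-in for torch.histc: each sample spreads its unit
    mass over the bins with Gaussian weights instead of one hard count."""
    width = (vmax - vmin) / bins
    # Bin centres, laid out like torch.histc's equal-width bins over [vmin, vmax]
    centers = vmin + (torch.arange(bins, dtype=x.dtype, device=x.device) + 0.5) * width
    if sigma is None:
        sigma = width  # assumption: kernel width of about one bin
    diff = x.reshape(-1, 1) - centers.reshape(1, -1)      # shape (N, bins)
    weights = torch.exp(-0.5 * (diff / sigma) ** 2)
    weights = weights / weights.sum(dim=1, keepdim=True)  # each sample sums to 1
    hist = weights.sum(dim=0)
    return hist / hist.sum()                              # probabilities

def soft_entropy(xhat, bitdepth, sigma=None, eps=1e-12):
    """Entropy in bits of the soft histogram of xhat (values assumed in [0, 1])."""
    p = soft_histogram(xhat, 2 ** bitdepth, sigma=sigma)
    # clamp_min keeps log2 finite; p * log2(p) is ~0 for such tiny p anyway
    return -(p * torch.log2(p.clamp_min(eps))).sum()

xhat = torch.rand(16, 8, requires_grad=True)
h = soft_entropy(xhat, bitdepth=2)
h.backward()  # gradient now reaches every element of xhat
print(h.item(), xhat.grad.shape)
```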

What’s your goal? Is it to minimize KL divergence between histograms of activations at each layer of a network to find the best parameters for quantization?
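If that is the goal, the KL divergence D_KL(P‖Q) = Σ p_i log(p_i / q_i) between two histograms with the same bins can be computed directly from their counts. A minimal sketch, assuming both histograms were built with identical bins (for example by torch.histc with the same bins, min and max); the function name `histogram_kl` and the counts below are hypothetical. It is meant for comparing fixed histograms, e.g. while searching over quantization parameters, not for backpropagating through torch.histc:

```python
import torch

def histogram_kl(p_counts: torch.Tensor, q_counts: torch.Tensor) -> torch.Tensor:
    """KL(P || Q) in nats between two histograms given as raw bin counts."""
    p = p_counts / p_counts.sum()
    q = q_counts / q_counts.sum()
    # xlogy(a, b) = a * ln(b), with 0 wherever a == 0, so empty P bins contribute 0;
    # a bin with p > 0 but q == 0 makes the divergence infinite, as it should
    return (torch.xlogy(p, p) - torch.xlogy(p, q)).sum()

p_counts = torch.tensor([10.0, 20.0, 30.0, 40.0])
q_counts = torch.tensor([25.0, 25.0, 25.0, 25.0])
print(histogram_kl(p_counts, q_counts).item())  # about 0.1064
print(histogram_kl(p_counts, p_counts).item())  # 0.0
```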