Using histogram operation before loss function

My problem is that I want to calculate the KL-divergence between two histograms. Once I have the outputs from the network, the easiest way to go is to compute a histogram with torch.histc or torch.histogram, and do the same for the labels. The problem is that the histogram operation is non-differentiable, so when backpropagation starts, no gradients can flow through it.
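To make the failure concrete, here is a minimal repro (bin range and sizes are illustrative assumptions): no matter whether the backward pass raises or the output is simply cut off from the graph, no usable gradient ever reaches the network outputs.

```python
import torch

# Hypothetical network outputs; the bin count/range are arbitrary here.
x = torch.randn(20, requires_grad=True)
h = torch.histc(x, bins=10, min=-3.0, max=3.0)

try:
    h.sum().backward()
except RuntimeError as e:
    # Depending on the PyTorch version, backward either raises here or
    # silently never reaches x, because histc has no defined derivative.
    print("backward failed:", e)

print(x.grad)  # no gradient arrives at the inputs
```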

I had three ideas from here:

  1. Detach the histogram operation from the computation graph:
  • if I do this, there is no path from the outputs and labels to the final loss function, because the loss gets the detached values from the histogram operation; when we differentiate, backpropagation cannot step past the break in the graph
  2. Create a gradient function that just passes the received gradients backwards without changing the values.
  • I implemented this; however, the loss is differentiated with respect to the bin values (or normalized bin values). If I have 10 bins and 20 outputs and I want to pass the gradients on unchanged, then only 10 values are passed backwards, and I have to pad the gradient list with arbitrary numbers to match the original 20 outputs.
  3. Just pass the loss value on to every output before the histogram, effectively circumventing the histogram operation.
  • the derivative of the loss is calculated with respect to the outputs before the histogram (does this even make sense?)
  • a custom gradient distribution function would have to be implemented here.
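One hedged sketch of a "gradient distribution" in the spirit of idea 3, which also avoids the padding problem of idea 2: record which bin each input element fell into during the forward pass, then in backward route each bin's incoming gradient back to every element of that bin. The bin edges, the function name, and the routing rule are all my assumptions, not an established recipe.

```python
import torch

class RoutedHistFunction(torch.autograd.Function):
    """Histogram whose backward routes each bin's gradient to the
    inputs that landed in that bin (a sketch, not a true derivative)."""

    @staticmethod
    def forward(ctx, input, bins=5, min_val=-16.1, max_val=35.8):
        flat = input.flatten()
        # Bin index of each element, clamped into the valid range.
        idx = ((flat - min_val) / (max_val - min_val) * bins).long()
        idx = idx.clamp(0, bins - 1)
        ctx.save_for_backward(idx)
        ctx.input_shape = input.shape
        return torch.histc(flat, bins=bins, min=min_val, max=max_val)

    @staticmethod
    def backward(ctx, grad_output):
        idx, = ctx.saved_tensors
        # Every input element receives the gradient of its own bin,
        # so the gradient has the same shape as the input.
        grad_input = grad_output[idx].reshape(ctx.input_shape)
        # One return value per forward argument (None for non-tensors).
        return grad_input, None, None, None

# Usage sketch with the question's bin settings:
x = torch.randn(20, requires_grad=True)
hist = RoutedHistFunction.apply(x, 5, -16.1, 35.8)
hist.sum().backward()  # x.grad now holds one bin-gradient per element
```

Whether this "gradient" is meaningful for training is exactly the open question in the post; it only guarantees that the gradient shape matches the input without arbitrary padding.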

I do not think that going deeper than this into the computation graph has any benefit, given the time investment and the problems that could arise further down the line when trying to use the model.

Other theoretical questions:

If I use the histograms as input for the KL-divergence loss, should I:

  • calculate PDFs from the histograms, or is it fine to use the raw bin counts?
  • make the predictions logarithmic or not (KLDivLoss — PyTorch 2.0 documentation)? The link here suggests that I should.
  • apply softmax? (I think it is just for the example in the link above, and I already have a distribution from the histogram; just making sure.)
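For reference, this is the convention torch.nn.functional.kl_div documents: the prediction side is expected as log-probabilities and the target side as probabilities, so normalized histograms (no softmax) plus a log on the prediction side fit it. The counts and the epsilon guard against log(0) are my own illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Illustrative bin counts (stand-ins for the two histograms).
pred_counts = torch.tensor([2.0, 5.0, 8.0, 4.0, 1.0])
target_counts = torch.tensor([1.0, 6.0, 7.0, 5.0, 1.0])

# Normalize counts to probability distributions (no softmax needed:
# the histogram already defines a distribution over the bins).
pred_prob = pred_counts / pred_counts.sum()
target_prob = target_counts / target_counts.sum()

# kl_div expects input as log-probabilities, target as probabilities.
# eps avoids log(0) for empty bins (an assumption, not from the question).
eps = 1e-8
loss = F.kl_div((pred_prob + eps).log(), target_prob, reduction="sum")
print(loss)
```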

My implemented custom histogram operation with a pass-through gradient function (number 2 in the list):

import torch

class HistFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        bins = 5
        min_val = -16.1  # renamed to avoid shadowing the min/max builtins
        max_val = 35.8
        output = torch.histc(input, bins=bins, min=min_val, max=max_val)
        # save_for_backward only accepts tensors; torch.histc has no
        # density argument, so any normalization must be done manually.
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Pass the bin gradients through unchanged; the remaining input
        # positions receive zeros (the padding problem from point 2).
        grad_input = torch.zeros_like(input)
        grad_input[:grad_output.size(0)] = grad_output
        return grad_input

Thanks in advance!