I'm trying to backpropagate through a weighted histogram inside my loss function.
Currently I am computing the histogram manually, as seen in `slow()` below, but I would like to use `histogramdd()`, as seen in `fast()`, since the computation time is a lot shorter and the result is the same. The only problem is that `histogramdd()` has no backward pass implemented.
How would I implement the backward pass on `histogramdd()` such that

``````
x = slow(theta, x, y)
x.mean().backward()
``````

and

``````
x = fast(theta, x, y)
x.mean().backward()
``````

have the same output?

This is my current code:

``````
import torch

batchsize = 512
bins = 30
neurons = 10

# dummy data
x = torch.randint(0, bins, [neurons, batchsize])  # x coordinate per theta
y = torch.randint(0, bins, [neurons, batchsize])  # y coordinate per theta
theta = torch.randn([neurons, batchsize], dtype=torch.double, requires_grad=True)  # weights (this definition is assumed; it was missing from the snippet)

def slow(theta, x, y):
    theta_xy = torch.zeros([neurons, bins, bins])

    # accumulate each weight t into the (x_bin, y_bin) cell of its neuron's histogram
    for i in range(neurons):
        for t, x_bin, y_bin in zip(theta[i], x[i], y[i]):
            theta_xy[i, x_bin, y_bin] += t

    return theta_xy.double()

def fast(theta, x, y):
    neuron = torch.arange(neurons).repeat_interleave(batchsize).double()

    # bin edges shifted by -0.1 so that each integer coordinate falls inside its own bin
    neuron_range = torch.arange(neurons + 1, dtype=torch.double) - 0.1
    x_range = torch.arange(bins + 1, dtype=torch.double) - 0.1
    y_range = torch.arange(bins + 1, dtype=torch.double) - 0.1

    x = x.ravel().double()
    y = y.ravel().double()
    a = torch.vstack((neuron, x, y))

    # weighted 3-d histogram over (neuron, x, y) -- this call has no backward implemented
    return torch.histogramdd(a.T, bins=[neuron_range, x_range, y_range],
                             weight=theta.ravel().double()).hist

print(torch.allclose(slow(theta, x, y), fast(theta, x, y))) #same result

``````

Hi Jonas!

I don’t really understand your use case, but, as a general rule, it doesn’t
make sense to backpropagate through a histogram, because the binning
process is not usefully differentiable.
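
To make that concrete: a histogram is piecewise constant in the binned values, so its gradient with respect to those values is zero almost everywhere (and undefined at the bin edges). A quick one-dimensional illustration:

``````
import torch

vals = torch.tensor([0.20, 0.50, 0.90])

h1 = torch.histogram(vals, bins=2, range=(0.0, 1.0)).hist
h2 = torch.histogram(vals + 1e-4, bins=2, range=(0.0, 1.0)).hist  # nudge the values a little

print(torch.equal(h1, h2))   # True -- the counts don't move, so the gradient is zero
``````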

In the case of your weighted histogram, you can, technically, calculate
gradients with respect to the weights as your `slow()` function demonstrates.
But I can’t think of a use case for this that isn’t contrived.

`slow()` returns a tensor with shape `[neurons, bins, bins]`, which
is `[10, 30, 30]` and therefore consists of 9000 elements. When you
calculate `x.mean()` you divide by this number of elements. Because each
element of `theta` is added into exactly one histogram cell with coefficient
one, this results in `theta.grad` having the same value for every one of its
elements, namely `1 / 9000`. So, sure, you’ve computed a gradient, but it
isn’t very useful.
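
You can check this directly (assuming `theta` is defined with `requires_grad = True`, as in the code above):

``````
theta = torch.randn([neurons, batchsize], dtype=torch.double, requires_grad=True)

out = slow(theta, x, y)
out.mean().backward()

print(theta.grad.unique())   # a single value, 1 / 9000 = 1.1111e-04
``````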

My assumption is that `histogramdd()` leaves `backward()` unimplemented
because any resulting gradients wouldn’t be useful.
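
That said, if you really did want `histogramdd()`’s speed together with the weight gradients that `slow()` gives you, you could wrap the call in a custom `torch.autograd.Function` and write the weight gradient by hand. Here is a sketch (not tested against your full use case), relying on the fact that, in your dummy data, the coordinates are already integer bin indices, so the backward pass can simply gather the incoming gradient at each sample’s cell:

``````
class WeightedHistogram(torch.autograd.Function):
    @staticmethod
    def forward(ctx, theta, x, y):
        neuron = torch.arange(neurons).repeat_interleave(batchsize)
        ctx.save_for_backward(neuron, x.ravel(), y.ravel())

        edges = [torch.arange(n + 1, dtype=torch.double) - 0.1
                 for n in (neurons, bins, bins)]
        coords = torch.vstack((neuron.double(), x.ravel().double(), y.ravel().double()))
        return torch.histogramdd(coords.T, bins=edges,
                                 weight=theta.ravel().double()).hist

    @staticmethod
    def backward(ctx, grad_output):
        # each element of theta was added into exactly one histogram cell,
        # so its gradient is the incoming gradient at that cell
        neuron, x_flat, y_flat = ctx.saved_tensors
        grad_theta = grad_output[neuron, x_flat, y_flat].reshape(neurons, batchsize)
        return grad_theta, None, None
``````

With this, `WeightedHistogram.apply(theta, x, y).mean().backward()` should populate `theta.grad` the same way `slow()` does (with the caveat, as above, that I don’t see what you would use it for).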

If you have a concrete use case in mind, could you describe it in more detail
and explain which gradients you would use and how you would use them?

Best.

K. Frank