Autograd for weighted histogram

I'm trying to backpropagate through a weighted histogram inside my loss function.
Currently I compute the histogram manually, as in slow() below, but I would like to use torch.histogramdd, as in fast(), because it is much faster and gives the same result. The only problem is that histogramdd has no backward pass implemented.
How would I implement the backward pass for histogramdd so that

out = slow(theta, x, y)
out.mean().backward()
print(theta.grad)

and

theta.grad = None  # clear the gradient accumulated by the first backward()
out = fast(theta, x, y)
out.mean().backward()
print(theta.grad)

have the same output?

This is my current code:

import torch

batchsize = 512
bins = 30
neurons = 10

# dummy data
theta = torch.randn([neurons, batchsize], requires_grad=True)
x = torch.randint(0, bins, [neurons, batchsize])  # x coordinate (bin index) per theta
y = torch.randint(0, bins, [neurons, batchsize])  # y coordinate (bin index) per theta

def slow(theta, x, y):
    # Accumulate each weight into its (neuron, x, y) bin one element at a time
    theta_xy = torch.zeros([neurons, bins, bins])

    for i in range(neurons):
        for t, x_bin, y_bin in zip(theta[i], x[i], y[i]):
            theta_xy[i, x_bin, y_bin] += t

    return theta_xy.double()

def fast(theta, x, y):
    # Neuron index of every flattened sample, matching the row-major ravel() order
    neuron = torch.arange(neurons).repeat_interleave(batchsize).double()

    # Bin edges shifted by -0.1 so that integer value k falls into bin k
    neuron_range = torch.arange(neurons + 1, dtype=torch.float64) - 0.1
    x_range = torch.arange(bins + 1, dtype=torch.float64) - 0.1
    y_range = torch.arange(bins + 1, dtype=torch.float64) - 0.1

    x = x.ravel()
    y = y.ravel()
    # Shape [3, num_samples]; a.T has one (neuron, x, y) row per sample
    a = torch.vstack((neuron, x.double(), y.double()))

    return torch.histogramdd(a.T, bins=[neuron_range, x_range, y_range], weight=theta.ravel().double())[0]

print(torch.allclose(slow(theta, x, y), fast(theta, x, y)))  # same result

Hi Jonas!

I don’t really understand your use case, but as a general rule, it doesn’t
make sense to backpropagate through a histogram because the binning
process is not usefully differentiable.
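
One way to see this: assigning a value to a unit-width bin is a step function such as floor(), which is flat between bin edges. A small illustration (added here, using torch.floor as a stand-in for bin assignment) shows that autograd reports a zero gradient for the coordinates:

```python
import torch

# Bin assignment with unit-width bins starting at 0 is a step function: floor(v)
v = torch.tensor([0.3, 2.7, 5.5], requires_grad=True)
bin_index = torch.floor(v)
bin_index.sum().backward()

# The step function is flat between bin edges, so its derivative is zero there
print(v.grad)  # tensor([0., 0., 0.])
```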

In the case of your weighted histogram, you can, technically, calculate
gradients with respect to the weights as your slow() function demonstrates.
But I can’t think of a use case for this that isn’t contrived.
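
As a sketch of a vectorized version of slow() (a suggestion added here, not from the original thread): since x and y are already integer bin indices, Tensor.index_put with accumulate=True builds the same [neurons, bins, bins] tensor in one call, and autograd already differentiates it with respect to the accumulated values. The helper name weighted_hist2d is hypothetical, and it assumes x and y are LongTensors with values in [0, bins).

```python
import torch

def weighted_hist2d(theta, x, y, bins):
    # Hypothetical helper. theta, x, y have shape [neurons, batchsize];
    # x and y are assumed to hold integer bin indices in [0, bins).
    neurons, batchsize = theta.shape
    # Neuron (row) index of every sample, expanded to the shape of theta
    neuron = torch.arange(neurons).unsqueeze(1).expand(-1, batchsize)
    hist = torch.zeros(neurons, bins, bins, dtype=theta.dtype)
    # accumulate=True sums all weights that land in the same bin
    hist = hist.index_put((neuron, x, y), theta, accumulate=True)
    return hist.double()


torch.manual_seed(0)
neurons, batchsize, bins = 10, 512, 30
theta = torch.randn(neurons, batchsize, requires_grad=True)
x = torch.randint(0, bins, (neurons, batchsize))
y = torch.randint(0, bins, (neurons, batchsize))

h = weighted_hist2d(theta, x, y, bins)
h.mean().backward()
# Each weight enters exactly one bin with coefficient 1, and mean() divides
# by 9000 elements, so every gradient entry is 1 / 9000
print(torch.allclose(theta.grad, torch.full_like(theta, 1 / 9000)))  # True
```

This avoids the Python loop while keeping the same gradient that slow() produces.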

slow() returns a tensor with shape [neurons, bins, bins], which
is [10, 30, 30] and therefore consists of 9000 elements. When you
calculate x.mean() you divide by this number of elements and doing so
results in theta.grad having the same value for all of its elements,
namely, 1 / 9000. So, sure, you’ve computed a gradient, but it isn’t very
useful.
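
The 1 / 9000 value follows directly from the structure of the weighted histogram. Writing H for the output of slow(), N for neurons, and B for bins (symbols introduced here for the derivation):

```latex
H_{i,a,b} = \sum_{k} \theta_{i,k}\,\mathbb{1}[x_{i,k}=a]\,\mathbb{1}[y_{i,k}=b]

L = \frac{1}{N B^2} \sum_{i,a,b} H_{i,a,b} = \frac{1}{N B^2} \sum_{i,k} \theta_{i,k}

\frac{\partial L}{\partial \theta_{i,k}} = \frac{1}{N B^2} = \frac{1}{10 \cdot 30 \cdot 30} = \frac{1}{9000}
```

Because every sample falls into exactly one bin, the indicator sums collapse, and the gradient depends on neither x, y, nor theta.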

My assumption is that histogramdd() leaves backward() unimplemented
because any resulting gradients wouldn’t be useful.
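
For completeness, a backward for the weight argument alone is short to write, since each weight contributes with coefficient 1 to exactly one bin. The sketch below wraps torch.histogramdd in a torch.autograd.Function. The class name WeightedHistogramDD is hypothetical; it assumes integer-valued coordinates with bin k centered on integer k (edges at k - 0.5), and it returns no gradient for the coordinates or the bin edges.

```python
import torch

class WeightedHistogramDD(torch.autograd.Function):
    """Hypothetical wrapper: torch.histogramdd forward, hand-written backward for weight only.

    Assumes each row of `points` holds integer-valued coordinates that are
    valid bin indices, so a sample's bin index equals its coordinates.
    """

    @staticmethod
    def forward(ctx, points, weight, bins):
        # points: [num_samples, D] float64, weight: [num_samples] float64
        hist, _ = torch.histogramdd(points, bins=bins, weight=weight)
        ctx.save_for_backward(points.long())  # bin index of every sample
        return hist

    @staticmethod
    def backward(ctx, grad_hist):
        (idx,) = ctx.saved_tensors
        # Each weight's gradient is the upstream gradient at its own bin
        grad_weight = grad_hist[tuple(idx.T)]
        # No gradient for points (binning is piecewise constant) or bins
        return None, grad_weight, None


torch.manual_seed(0)
neurons, batchsize, nbins = 10, 512, 30
theta = torch.randn(neurons, batchsize, dtype=torch.float64, requires_grad=True)
x = torch.randint(0, nbins, (neurons, batchsize))
y = torch.randint(0, nbins, (neurons, batchsize))

neuron = torch.arange(neurons).repeat_interleave(batchsize)
points = torch.stack((neuron, x.ravel(), y.ravel()), dim=1).double()
# Edges at k - 0.5 so that integer value k falls into bin k
edges = [torch.arange(n + 1, dtype=torch.float64) - 0.5 for n in (neurons, nbins, nbins)]

hist = WeightedHistogramDD.apply(points, theta.ravel(), edges)
hist.mean().backward()
print(torch.allclose(theta.grad, torch.full_like(theta, 1 / 9000)))  # True
```

As the derivation of the uniform 1 / 9000 shows, this only reproduces the weight gradient; it cannot give a meaningful gradient with respect to where samples land.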

If you have a concrete use case in mind, could you describe it in more detail
and explain which gradients you would use and how you would use them?

Best.

K. Frank