Backward function of median

It would be nice if I could take the median of losses in a batch, instead of the mean, and get gradients based on that. But does torch.median() have a useful backward function, or is it only indexing, like torch.max()? Thanks for your help.

You already can, but the median's gradient is generally very sparse by the nature of the median: only the one element that holds the median value gets a gradient (PyTorch does not interpolate between the two middle values for an even number of inputs). The way ties for the median are handled in PyTorch is likely not based on rigorous mathematics, probably because that case lacks practical relevance.
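
To make that concrete, here is a minimal sketch for the batch-of-losses case (losses is just a stand-in tensor):

import torch

# Stand-in for per-sample losses; only the element holding the median
# value receives a gradient.
losses = torch.randn(5, requires_grad=True)
losses.median().backward()
print(losses.grad)  # a single 1.0, zeros everywhere else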

Best regards

Thomas


How exactly does the gradient computation work here? I've spent hours researching this, but I was not able to match PyTorch's gradients from torch.median().

Consider

import torch

a = torch.randn(3, 4, requires_grad=True)
b = a.median(1)
print("a =", a)
print("b =", b)

giving

a = tensor([[-0.8903,  1.3827,  0.8102,  0.2972],
        [-0.5800, -0.3084, -0.6165, -0.5376],
        [-0.5144, -1.6414, -0.1507,  0.5274]], requires_grad=True)
b = torch.return_types.median(
values=tensor([ 0.2972, -0.5800, -0.5144], grad_fn=<MedianBackward1>),
indices=tensor([3, 0, 0]))

Then the backward will just be the grad_out scattered to the appropriate indices:

b.values.sum().backward()
print(a.grad)

gives

tensor([[0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])
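
You can reproduce this by hand with a scatter (a sketch, reusing a and b from above and assuming an upstream gradient of ones):

grad_out = torch.ones_like(b.values)
# Scatter grad_out along dim 1 into the median positions recorded in b.indices.
manual_grad = torch.zeros_like(a).scatter(1, b.indices.unsqueeze(1), grad_out.unsqueeze(1))
print(manual_grad)  # matches a.grad above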

That is also why the indices are non-optional: they need to be computed for the backward anyway.

Best regards

Thomas


Thank you Tom. I figured this out shortly after asking this question. :) For those interested:

import numpy as np

def backward_median(x, v):
    # Route the upstream gradient v to the element(s) of x that hold the
    # median value; every other entry gets a zero gradient.
    # Note: np.median interpolates for an even number of elements, so this
    # matches torch.median only when the median value actually occurs in x
    # (e.g. for odd-length input).
    median_x = np.median(x)

    grad = np.zeros(x.shape)
    grad[x == median_x] = v[x == median_x]
    return grad

This is a numpy version.
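
As a quick sanity check (a sketch; the input length is odd, so np.median does not interpolate and agrees with torch.median):

import numpy as np
import torch

x = torch.randn(5, requires_grad=True)
x.median().backward()

# Compare autograd's result against the numpy reimplementation,
# using an upstream gradient of ones.
manual = backward_median(x.detach().numpy(), np.ones(5))
print(x.grad.numpy())
print(manual)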