I noticed an issue while using torch.median() on a PyTorch tensor: it does not return the median of all elements that I expect. Here's an example:
```python
import torch

e = torch.tensor([[5., 8., 0.],
                  [8., 9., 6.]])
print(torch.median(e))  # Output: tensor(6.)
```
However, when I sort and manually compute the median:
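```python
sorted_e = torch.sort(e.flatten()).values  # tensor([0., 5., 6., 8., 8., 9.])
# Even number of elements: average the two middle values
median_value = (sorted_e[len(sorted_e) // 2 - 1] + sorted_e[len(sorted_e) // 2]) / 2
print(median_value)  # Output: tensor(7.)
```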
The sorted values are 0, 5, 6, 8, 8, 9, so the two middle values are 6 and 8 and I expected the median to be 7.0. However, torch.median(e) returns 6.0.
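For comparison, PyTorch's documentation for torch.median notes that when the input has an even number of elements the median is not unique and the lower of the two middle values is returned; it points to torch.quantile with q=0.5 for the mean of both. The sketch below, assuming the goal is the conventional average-of-middles median, contrasts the two. The helper `mean_of_middles_median` is hypothetical (not a PyTorch API) and simply generalizes the manual computation above to both odd and even lengths.

```python
import torch

def mean_of_middles_median(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: median over all elements of `t`,
    averaging the two middle values when the element count is even."""
    flat = torch.sort(t.flatten()).values
    n = flat.numel()
    mid = n // 2
    if n % 2 == 1:
        # Odd count: a single middle element
        return flat[mid]
    # Even count: mean of the two middle elements
    return (flat[mid - 1] + flat[mid]) / 2

e = torch.tensor([[5., 8., 0.],
                  [8., 9., 6.]])

print(torch.median(e))            # lower of the two middle values: tensor(6.)
print(mean_of_middles_median(e))  # average of 6 and 8: tensor(7.)
print(torch.quantile(e, 0.5))     # linear interpolation at q=0.5: tensor(7.)
```

torch.quantile requires a floating-point input, which `e` already is; with its default linear interpolation it lands halfway between the two middle values for an even-length input.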