How can I calculate the 5% & 95% percentiles of a whole dataset of images?

Hi

I have a pretty big dataset of images. I would like to calculate the 5th and 95th percentiles of the pixel values over the whole dataset.

Is there a method for that, and how can I do it in PyTorch?

Please let me know if you have any questions.

Hi Ali,

if your values are 8 or 16 bit integers (which is quite common), I’d probably just compute a histogram.
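
For the integer case, a minimal sketch of the histogram idea could look like the following. The function name `histogram_percentiles`, the `batches` iterable and the random stand-in data are assumptions made for illustration; the sketch also assumes all pixel values lie in `[0, num_levels)`.

```python
import torch

def histogram_percentiles(batches, num_levels=256, qs=(0.05, 0.95)):
    """Accumulate one histogram over all batches, then read percentiles off its CDF.

    `batches` is any iterable of integer tensors with values in [0, num_levels).
    """
    hist = torch.zeros(num_levels, dtype=torch.long)
    for batch in batches:
        # bincount needs a 1-D integer tensor; minlength keeps the shape fixed
        hist += torch.bincount(batch.flatten().long(), minlength=num_levels)
    cdf = hist.cumsum(0).double() / hist.sum()
    # smallest pixel value whose cumulative share reaches q
    return [int((cdf >= q).nonzero().min()) for q in qs]

# stand-in for a real image loader: 10 batches of random 8-bit "images"
torch.manual_seed(0)
batches = [torch.randint(0, 256, (4, 3, 32, 32), dtype=torch.uint8) for _ in range(10)]
p05, p95 = histogram_percentiles(batches)
print(p05, p95)
```

Only the histogram has to stay in memory, so a single pass over the dataset is enough. For 16-bit images the same sketch should work with `num_levels=65536`.
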
If they are not (e.g. floating-point values), you could do this iteratively: quantize to 8 bits by rounding both down and up, keeping statistics for both the rounded-down and the rounded-up values. Then you know the 5th percentile lies between the rounded-down 5% value and the rounded-up 5% value, and you can quantize just that range to 8 bits again. If you want it easy, just clamp to the range: everything outside it gets put on the boundaries, but the 5% quantile is in the “higher resolution range”.

To make things concrete:

import torch

data = torch.randn(500000, dtype=torch.double)
# reference value; to match data.quantile(0.05) you would have to replicate its interpolation
q05_true = data.sort().values[int(len(data) * 0.05)]
# upper and lower bound
q05_max = torch.max(data)
q05_min = torch.min(data)

HIST_SIZE = 1000

done = False
while not done:
    # each iteration of this loop corresponds to one full pass over your dataset
    transformed = ((data - q05_min) * HIST_SIZE / (q05_max - q05_min)).clamp(min=0, max=HIST_SIZE)
    ceil = transformed.ceil().long()
    floor = transformed.floor().long()

    vals_c, counts_c = torch.unique(ceil, return_counts=True)
    vals_f, counts_f = torch.unique(floor, return_counts=True)

    # refined upper and lower bound
    q05_max_new = (vals_c[(counts_c.cumsum(-1).double()/data.numel() > 0.05).nonzero().min()]) * (q05_max - q05_min) / HIST_SIZE + q05_min
    q05_min_new = (vals_f[(counts_f.cumsum(-1).double()/data.numel() <= 0.05).nonzero().max()]) * (q05_max - q05_min) / HIST_SIZE + q05_min

    assert q05_min_new <= q05_true <= q05_max_new, f"{q05_min_new} <= {q05_true} <= {q05_max_new}"
    q05_min, q05_max = q05_min_new, q05_max_new
    done = len(vals_c) == 3
    print(f"{q05_min}, {q05_max}")

vals, counts = data.clamp(min=q05_min, max=q05_max).unique(return_counts=True)
q05 = vals[(counts.cumsum(-1) <= int(len(data) * 0.05)+1).nonzero().max()]
print("found", q05.item(), "true", q05_true.item())
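
When the data does not fit in memory, one refinement pass can be accumulated batch by batch. The following is a sketch of a variation on the loop above, not the same code: it keeps a single floor histogram per pass instead of the floor/ceil pair, and pads the new bracket by one bin on each side to guard against floating-point rounding at bin edges. The name `refine_bounds`, the `batches` list and the tolerance `1e-6` are assumptions made for illustration.

```python
import torch

def refine_bounds(batches, lo, hi, q=0.05, bins=1000):
    """One pass over the data: shrink the bracket [lo, hi] around the q-quantile."""
    hist = torch.zeros(bins, dtype=torch.long)
    below = 0  # number of values strictly below lo
    total = 0
    for batch in batches:
        x = batch.flatten().double()
        total += x.numel()
        below += int((x < lo).sum())
        inside = x[(x >= lo) & (x <= hi)]
        # floor to a bin index; x == hi falls into the last bin
        idx = ((inside - lo) / (hi - lo) * bins).long().clamp(max=bins - 1)
        hist += torch.bincount(idx, minlength=bins)
    rank = int(total * q)  # 0-based rank of the quantile in sorted order
    cum = below + hist.cumsum(0)
    k = int((cum > rank).nonzero().min())  # bin containing the element of that rank
    width = (hi - lo) / bins
    # pad by one bin on each side so rounding at bin edges cannot lose the quantile
    return lo + (k - 1) * width, lo + (k + 2) * width

# stand-in for a dataset: 10 batches that are re-read on every pass
torch.manual_seed(0)
batches = [torch.randn(50_000, dtype=torch.double) for _ in range(10)]
lo = min(float(b.min()) for b in batches)
hi = max(float(b.max()) for b in batches)
passes = 0
while hi - lo > 1e-6:
    lo, hi = refine_bounds(batches, lo, hi, q=0.05)
    passes += 1
print(f"5% quantile lies in [{lo}, {hi}] after {passes} passes")
```

Each pass shrinks the bracket by roughly a factor of `bins / 3`, so a handful of passes over the dataset should be enough for most tolerances. Calling the same function with `q=0.95` brackets the upper percentile.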

Actually a fun task, thanks for sharing the problem!

Best regards

Thomas
