black ceiling wall

如何快速找到中位數

這個問題是小時候在一些程式競賽的書上看到的。雖然感覺很酷炫,但因為教科書上沒有教,考試也不會考,感覺會了這個也不太實用,應該只有打程式競賽時會用到,就先一直記在心底沒有拿出來想。這次看了篇文章介紹這個演算法,算是把他研究清楚一些了。有找到一個 O(n) 為線性的演算法可以用來計算中位數。其實也不只是可以計算中位數,可以改為計算第 k 位數,或是任何的百分位數。

排序的時間極限

算中位數有一個關鍵的點,是你的原始數列總數是奇數還是偶數。如果是奇數的話就直接找位於中間的數即可。如果是偶數的話,要找位於最中間位數的兩個數加起來平均。以上圖例子,中位數是 1 。以下圖來說,中位數是 1.5 。

為了確定一個長度為 n 的數列,到底要找第幾位作為中間位數。我將頭跟尾相加除以二,找到最附近兩個位數,如果整除的話取 ceil 跟 floor 是同一位。

def find_half_from(n):
    return (math.floor((n - 1)/2), math.ceil((n - 1)/2))

接著就先寫一個慢的版本。這個慢的版本也不是寫的毫無價值,等下會有個彩蛋可以用到。這個版本除了時間複雜度是 O(nlog(n)) ,比較不理想之外,沒有什麼明顯的缺點。

def slow_select(k, arr):
    return sorted(arr)[k]

def slow_find_medium(arr):
    (a, b) = find_half_from(len(arr))
    return (slow_select(a, arr) + slow_select(b, arr))/2

中位數的定義就是需要把東西排列之後取中間的數。這樣的話,光是排序就會有個極限,理論上比較型排序的最小極限時間複雜度是 O(nlog(n)),計算中位數怎麼可能比排序還快呢?

還真的有可能。關鍵點在於,找到中位數並不一定要排序好才能找到。你只要知道你的這個數,跟數列中的其他數字比起來位於中間位即可。

快速計算中位數

我們先選一個其中的元素作為 pivot ,用這數字跟數列中的數字比大小,把數列分成三堆。大的放右邊一堆,小的放左邊一堆,一樣的話放中間一堆。如果剛好抽中位數,k 就會介於左邊那堆跟右邊那堆之間,也就是 k 會小於左邊堆加中間堆的大小。如果沒抽中,你也會知道要往小的那堆找還是大的那堆找,因為你知道你要找的是第 k 個。

這種用 pivot 分成堆的概念,有點類似快速排序法,因為是同一個人 Tony Hoare 發明的。至於怎麼選呢? 簡單起見我們先選第一個元素作為 pivot 。如果剛好 pivot 就是要我們要選的,那就直接回傳,否則就判斷往左堆找或是往右堆找。往右堆找的話,就要記得減去對應的 offset 。後續會提到更好的選擇 pivot 的方法,來降低最差情況下的時間複雜度。

要分成三組,而不是分成左右兩組就好,是因為怕 pivot 的數字剛好切的不在 arr 裡面,仍然可以順利執行。最後面我們會用 median_of_medians 來選,就有可能選到的 pivot 不在 arr 內。

def trivial_pick_pivot(arr):
    return arr[0]

def quick_select(k, arr, pivot_function = trivial_pick_pivot):
    pivot = pivot_function(arr)    
    right = [x for x in arr if x > pivot]
    mid = [x for x in arr if x == pivot]
    left = [x for x in arr if x < pivot]
    if k < len(left):
        return quick_select(k, left)
    if k < len(left) + len (mid):
        return pivot
    if k >= len(left) + len (mid):
        return quick_select(k - (len(left) + len(mid)), right)

時間複雜度

想像一下這個演算法的時間複雜度。平均來講每次用 pivot 比一輪,比過一次之後就可以把問題縮小,平均情況可以縮小成 n/2 的子問題。這就是等比級數,因此時間複雜度為 O(n) 。但最差情況每次只能所小成 n-1 的子問題,這樣的話因為是等差級數, worst case 時間複雜度仍是 O(n^2) 。

那怎麼辦! 這個演算法突然看起來很不帥了!

Medians of Medians

有幾個神人研發出 Median of medians 的方法,可確保每次選 pivot 的時候都可以去掉一部分比例 (例如左上象限) 的元素,轉換為更小的子問題。

def chunked(arr, size):
    return [arr[i : i + size] for i in range(0, len(arr), size)]

def median_of_medians(arr):
    if len(arr) <= 5:
        return slow_find_median(arr)
    chunks = chunked(arr, 5)
    medians = [sorted(chunk)[2] for chunk in chunks if len(chunk) == 5]
    return median_of_medians(medians)

雖然會花費額外的時間選 pivot ,但可以確保最差情況下的時間複雜度是 O(n)。但我沒有深究這是怎麼得到的,感覺計算起來有點複雜。我的演算法分析已經忘了差不多了。

直覺來想確實有機會,因為每次 quick_select 都可以利用 pivot 至少去除紅框中大約 (3n/5)/2 的元素,轉變成 7n/10 的子問題,是一個等比級數的 O(n) 等級問題。產生 pivot 本身也只需要公比 1/5 的等比級數,這也是 O(n) 等級的問題。


Posted

in

by

Comments

Leave a Reply

Your email address will not be published. Required fields are marked *