美柑の部屋

涙は見せないと誓った。

Loading…

面试常见问题盘点:第k大元素

数组本身结构简单,但是对数组进行操作会涉及到很多经典算法,这使得数组成为了面试中最热门的考点之一。这篇文章就总结一下求数组的第k小/第k大元素,或者中位数的问题。

一、求数组中第k小/第k大元素

问题:输入一个数组以及一个整数k,返回数组中第k大的元素。
看到这个问题,我们的第一反应可能是先将输入的数组排序,这样一来,算法的时间复杂度就至少是O(nlogn)。其实,解决这类问题有一种更快的方法,它是基于快速排序的partition函数实现的。

1.1 partition函数概述

partition函数的作用是选择一个标记元素pivot(一般取数组第一个元素为标记元素),将原数组划分为两部分,左边一部分元素比pivot小,右边一部分元素比pivot大,pivot位于中间。该函数的大致流程如下:
1. 初始:记录两个下标p1p2p1记录了比pivot小的部分的最后一个元素的下标,初始值为0p2记录了当前正在遍历的元素的下标,初始值为1
2. 迭代:每当p2遍历到一个比pivot小的元素,则将p1++,然后将arr[p1]的值和arr[p2]的值进行交换。
3. 结束:在遍历完毕后,arr[1..p1]存放的就是所有小于pivot的元素,arr[p1+1..p2]存放的就是所有的大于pivot的元素。我们将arr[0]arr[p1]交换,这样标记元素pivot就位于下标为p1的位置,它左侧的元素小于它,它右侧的元素大于它。

1.2 基于partition函数的求第k大元素的算法

首先对整个数组进行一次划分,观察大于pivot的元素数量size。如果size == k-1,说明pivot正好是我们要找的第k大的元素,直接返回pivot即可;如果size > k-1,说明第k大的元素位于右半部分,我们需要对右半部分递归调用partition函数,找右半部分第k大的元素;如果size < k-1,说明第k大的元素位于左半部分,我们需要对左半部分递归调用partition函数,找左半部分第k-1-size大的元素。

数组第k大的元素
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
int findKthLargest(vector<int>& nums, int k) {
    if(nums.size() == 1)
        return nums[0];
    return findResult(nums, 0, nums.size()-1, k);
}

int findResult(vector<int>& nums, int start, int end, int k){
    if(start == end && k == 1)
        return nums[start];

    //Partition:
    int p1 = start;
    int p2 = start + 1;
    for(; p2 <= end; p2++){
        if(nums[p2] < nums[start]){
            p1++;
            int tmp = nums[p2];
            nums[p2] = nums[p1];
            nums[p1] = tmp;
        }
    }
    int tmp = nums[start];
    nums[start] = nums[p1];
    nums[p1] = tmp;

    int leftSize = p1 - start;
    int rightSize = end - p1;
    if(k == rightSize + 1)
        return tmp;
    else if(k <= rightSize)
        return findResult(nums, p1+1, end, k);
    else
        return findResult(nums, start, p1-1, k-rightSize-1);
}

该算法的平均时间复杂度为O(n),与快速排序一样,最坏情况下,时间复杂度为O(n2)

二、数组前k小/前k大的元素

问题:输入一个数组以及一个整数k,返回数组中前k小的元素。

2.1 方法一:基于partition函数的算法

定义partition函数,其返回值为pivot在划分之后所处的下标。首先对整个数组进行一次划分,并观察函数的返回值。
如果返回值等于k,我们知道arr[0..k-1]肯定都比pivot小,arr[k+1..n-1]肯定都比pivot大,换句话说,arr[0..k-1]就是数组的前k小的元素;如果返回值小于k,我们对右半部分进行递归调用,直到返回值等于k为止;如果返回值大于k,我们对左半部分进行递归调用,直到返回值等于k为止。
与上一个问题相同,该算法的平均时间复杂度为O(n)

数组前k小的元素
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
int partition(vector<int>& arr, int start, int end){
    if(start == end)
        return start;

    int p0 = start;
    int p1 = start + 1;
    int pivot = arr[start];
    for(; p1 <= end; p1++){
        if(arr[p1] < pivot){
            p0++;
            int tmp = arr[p0];
            arr[p0] = arr[p1];
            arr[p1] = tmp;
        }
    }
    int tmp = arr[p0];
    arr[p0] = arr[start];
    arr[start] = tmp;
    return p0;
}

vector<int> findLeastKElements(vector<int>& arr, int k){
    vector<int> res;
    if(k <= 0 || arr.size() == 0)
        return res;

    int start = 0;
    int end = arr.size() - 1;
    int index = partition(arr, start, end);
    while(index != k){
        if(index < k){
            start = index + 1;
            index = partition(arr, start, end);
        }else{
            end = index - 1;
            index = partition(arr, start, end);
        }
    }
    for(int i=0; i<k; i++)
        res.push_back(arr[i]);
    return res;
}

上述方法尽管效率较高,但是也存在一个致命的问题:运行时需要将所有数据载入内存。这就决定了该算法不适合海量数据的场合,比如从一亿个数中找出前一百个数。为了解决这个问题,我们便引入了一种基于堆的方法。

2.2 方法二:基于堆的算法

建立一个大根堆,然后遍历数组中的所有元素,在遍历过程中:
如果堆的大小小于k,则直接将该元素进堆。如果堆的大小等于k,则比较该元素与堆顶元素的大小。由于是大根堆,堆顶元素肯定是堆中最大的元素。如果该元素比堆顶元素还要大,那么该元素不可能是前k小的元素之一(因为堆中已经有k个比它小的元素了),我们可以直接将该元素抛弃;如果该元素比堆顶元素小,我们知道堆顶元素已经不再属于前k小的元素之一了,我们将堆顶元素弹出,然后将该元素进堆。
在算法执行的过程中,保证堆的大小不大于k。这样算法执行完毕后,堆中的元素就是原数组前k小的元素。

数组前k小的元素
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
int main(){
    int n, k;
    cin >> n >> k; //n为数组元素个数
    vector<int> numbers(n);
    for(int i=0; i<n; i++) //逐个输入数组元素
        cin >> numbers[i];

    priority_queue<int, vector<int>, less<int>> heap; //建立大根堆
    for(int i=0; i<n; i++){
        if(heap.size() < k)
            heap.push(numbers[i]);
        else{
            int top = heap.top(); //获取堆顶元素
            if(top < numbers[i])
                continue;
            else if(top > numbers[i]){
                heap.pop();
                heap.push(numbers[i]);
            }
        }
    }
    while(heap.size() > 0){ //输出结果
        cout << heap.top() << " ";
        heap.pop();
    }
    return 0;
}

该方法不需要将所有数据载入内存,只需要在内存中保存一个大小为k的堆,所以更适合处理海量数据。因为每次调整堆的时间复杂度为O(logk),并且最坏情况下,每遍历到一个元素就需要对堆进行一次调整,所以该方法的时间复杂度为O(nlogk)。总体来讲,该方法适用于n较大但是k非常小的情况。

三、求长度可变数组的中位数

问题:有一个数组支持两种操作:一是向该数组中加入一个元素(int类型);二是求该数组中所有元素的中位数。请实现该数组。
通过本文的第一节,我们已经知道了计算大小为n的数组中位数的方法:如果n是奇数,那么中位数就是第n/2 + 1大的元素;如果n是偶数,中位数则是第n/2n/2 + 1大的元素的平均值。我们可以直接利用上面提到的基于partition的方法进行递归求解。
但是本问题中,数组的长度是逐渐增加的,我们不可能在每次往数组中添加元素之后,都利用O(n)的时间来求新的中位数。我们需要找到更快的方法。

3.1 最大最小堆法更新中位数

我们保存一个大根堆和一个小根堆,并且在往数组中添加元素之后,保证以下三个条件:
1、大根堆中的元素是整个数组中较小的一半元素;
2、小根堆中的元素是整个数组中较大的一半元素;
3、大根堆和小根堆的大小之差的绝对值不超过1。
具体来说,在添加一个元素num时:
如果num小于大根堆堆顶,说明num位于数组前半部分,将其放入大根堆中;
如果num大于小根堆堆顶,说明num位于数组后半部分,将其放入小根堆中;
如果num介于二者之间,则放入哪个堆中都是可以的。
然后观察两个堆大小之差的绝对值是否超过1,如果是,我们需要调节两个堆的平衡,具体做法为:从元素数量较多的堆中取出一个元素,放入另一个堆中。

3.2 最大最小堆法获取中位数

如果大根堆与小根堆的大小相差1,说明原数组大小为奇数。大小较大的那个堆的堆顶元素即为中位数。如果大根堆与小根堆的大小相等,说明原数组大小为偶数。两个堆堆顶元素的平均值即为中位数。

求长度可变数组的中位数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class MedianFinder {
private:
    priority_queue<int, vector<int>, less<int>> greaterRoot; //大根堆
    priority_queue<int, vector<int>, greater<int>> smallerRoot; //小根堆

public:
    MedianFinder(){
        greaterRoot = priority_queue<int, vector<int>, less<int>>();
        smallerRoot = priority_queue<int, vector<int>, greater<int>>();
    }

    // Adds a number into the data structure.
    void addNum(int num) {
        if(smallerRoot.empty()){
            greaterRoot.push(num);
            //观察大根堆的大小是否比小根堆多1,如果不是,则将大根堆的堆顶元素取出,放到小根堆中。
            if(greaterRoot.size() >= 2){
                int tmp = greaterRoot.top();
                greaterRoot.pop();
                smallerRoot.push(tmp);
            }
        }else{
            if(num < greaterRoot.top()){
                greaterRoot.push(num);
                if(greaterRoot.size() - smallerRoot.size() >= 2){
                    int tmp = greaterRoot.top();
                    greaterRoot.pop();
                    smallerRoot.push(tmp);
                }
            }else if(num > smallerRoot.top()){
                smallerRoot.push(num);
                if(smallerRoot.size() - greaterRoot.size() >= 2){
                    int tmp = smallerRoot.top();
                    smallerRoot.pop();
                    greaterRoot.push(tmp);
                }
            }else{
                int sizeA = greaterRoot.size();
                int sizeB = smallerRoot.size();
                if(sizeA <= sizeB)
                    greaterRoot.push(num);
                else
                    smallerRoot.push(num);
            }
        }
    }

    // Returns the median of current data stream
    double findMedian() {
        int sizeA = greaterRoot.size();
        int sizeB = smallerRoot.size();
        if(sizeA == sizeB)
            return (greaterRoot.top() + smallerRoot.top()) / 2.0;
        else if(sizeA > sizeB)
            return greaterRoot.top();
        else
            return smallerRoot.top();
    }
};

该算法每次添加元素的时间复杂度为O(logn),获取中位数的时间复杂度为O(1)

扩展:如果要求出数据流中第n/10大的元素该怎么做?展开

评论