【算法】TopN问题

最近看到了TopN的题和一些常用的场景,发现此类问题还是很有意思的。今天特来整理一下。

首先是问题描述:

给定一个数组,{3, 1, 8, 7, 4, 5, 2, …}, 数据规模为k, 要求找到最大(小)的前n个

很容易想到,我们把数据排序一下,然后直接取前n个不就完了。

1
2
arr.sort();
return Arrays.copyOf(arr, n);

如果是比较快的排序算法,以上时间复杂度为O(klogk)。

但是我们实际要关注的是我们真正需要的前N个,而不是对所有的数据进行排序,虽然最后得到的结果是有序的。

如果我们需要更优化的算法,而且不再强调结果的顺序,我们可以用堆(heap)来解决这个问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
/**
* 算法思想(以取最小的TopN为例)
* 将数组的前N个做成大顶堆,顶上元素最大。
* 遍历剩下的数组arr[n]...arr[length - 1]
* 如果arr[i] >= 顶,pass。
* 否则和顶交换,重新调整堆
*/
public static int[] findLeastTopN(int[] arr, int n) {
if (arr.length <= n) {
return arr;
}
initHeap(arr, n);
for (int i = n; i < arr.length; i++) {
if (arr[i] < arr[0]) {
swap(arr, 0, i);
adjust(arr, 0, n);
}
}
return Arrays.copyOf(arr, n);
}
// 初始化堆
public static void initHeap(int[] arr, int len) {
int lastHasSon = len / 2 - 1;
for (int i = lastHasSon; i >= 0; i--) {
adjust(arr, i, len);
}
}
// 交换
public static void swap(int arr[], int i, int k) {
int temp = arr[i];
arr[i] = arr[k];
arr[k] = temp;
}
// 调整
public static void adjust(int[] arr, int i, int len) {
int node = arr[i];
int l, r;
l = r = -1;
if (i * 2 + 1 < len) {
l = arr[i * 2 + 1];
}
if (i * 2 + 2 < len) {
r = arr[i * 2 + 2];
}
if (node >= r && node >= l) {
return;
} else if (l >= r) {
swap(arr, i, i * 2 + 1);
adjust(arr, i * 2 + 1, len);
} else {
swap(arr, i, i * 2 + 2);
adjust(arr, i * 2 + 2, len);
}
}

上面的算法取前n个作初始堆,然后依次扫描后k - n个,很容易得到,此算法时间复杂度为O(klogn), 显然比上一个算法更优,当n比较小时,复杂度趋近O(n)。

Java中的PriorityQueue是也是基于heap实现的,这里贴出利用优先级队列的解法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public static int[] findLargestTopN(int[] arr, int n) {
if (n >= arr.length) {
return arr;
}
// 重新实现一下comparator,保证数字大的优先级高
PriorityQueue<Integer> pq = new PriorityQueue<>((o1, o2) -> o2 - o1);
for (int i = 0; i < arr.length; i++) {
pq.offer(arr[i]);
}
int[] res = new int[n];
for (int i = 0; i < n; i++) {
res[i] = pq.poll();
}
return res;
}

接下来再介绍另外一种O(klogn)的解法,我们知道在快排中我们选取pivot,然后以pivot为界,将数据划分为大于pivot和小于pivot的两个部分。假如pivot所在的位置为i,虽然两侧数据无序,但是很容易得知,pivot是第i + 1大的元素,pivot的左侧就是最小的i个数,那其实只要按照快排的思路,找到第k大的轴,那么就找到了答案。

下面贴出找到第K大的数的代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/**
先找到pivot,据此将区间分为两部分,如果privot的位置恰好是k的话,直接返回
否则如果pivot > k,说明第k大的数在pivot左边,反之在右边,继续划分即可。
如果区间只有一个了,说明找到了。
*/
public static int find(int[] arr, int s, int e, int k) {
if (s == e && s + 1 == k) {
return arr[s];
}
int pivot = arr[e];
int i = s;
for (int j = s; j <e; j++) {
if (arr[j] < pivot) {
int temp = arr[i];
arr[i++] = arr[j];
arr[j] = temp;
}
}
arr[e] = arr[i];
arr[i] = pivot;
if (i + 1 == k) {
return arr[i];
} else if (i + 1 > k) {
return find(arr, s, i - 1, k);
} else {
return find(arr, i + 1, e, k);
}
}

再把思路往下延伸一下,放在大数据量环境之下。

有1G的文件,文件中每行一个16字节之内的单词,内存只有16M,要求找出其中出现频率最高的n个,要求顺序。

显然不可能将这些数据全部加载进内存。分治算法是解决此类问题的一个可行方案。

首先我们将1G文件根据字母序或者哈希算法划分到多个桶(小文件)内,注意要将相同的单词划到同一个桶内,统计桶内单词的词频,生成(单词,频率)对,复杂度O(k)。然后根据TopN算法找出词频前N个另行存储,复杂度O(klogn)。为什么要取前n个呢,极端情况下,词频最高的n个集中在一个桶内。

最后合并结果,合并的过程可以使用归并排序。


大佬们见笑了,如有谬误,请在评论区指正,感谢。