TopK数的解法汇总

寻找第K大

https://www.nowcoder.com/practice/e016ad9b7f0b45048c58a9f27ba618bf?tpId=190&&tqId=35209&rp=1&ru=/ta/job-code-high-rd&qru=/ta/job-code-high-rd/question-ranking

TopK问题,不管是求前K大/前K小/第K大/第K小等,都有下面四种解法:

  1. O(N):用快排变形最最最高效解决TopK问题

  2. O(NlogK):大根堆(前K小)/小根堆(前K大)

  3. O(NlogK):二叉搜索树

  4. O(N): 对于数据范围有限的情况可以直接计数排序O(N)高效解决

我们上一道例题:
图片说明


  • 方法一 快排
    时间复杂度为O(n)

    class Solution {
      public int[] getLeastNumbers(int[] arr, int k) {
          if(k == 0 || arr.length == 0){
              return new int[0];
          }
          //因为我们要的是前k个数,对应下标为k-1
          return quickSort(arr, 0, arr.length - 1, k - 1);
      }
    
      public static int[] quickSort(int[] arr, int L, int R, int k){
          int i = partition(arr, L, R);
          //恰好相等,说明左边已经有k个较小的数了
          if(i == k){
              return Arrays.copyOf(arr, i+1);
          }
    
          return i > k ? quickSort(arr, L, i-1, k) : quickSort(arr, i+1, R, k);
      }
    
      //填坑的过程
      public static int partition(int[] arr, int L, int R){
          int base = arr[L];
          int l = L, r = R;
          while(l < r){
              while(l < r && arr[r] >= base){
                  r--;
              }
              if(l < r){
                  arr[l] = arr[r];
                  l++;
              }
              while(l < r && arr[l] < base){
                  l++;
              }
              if(l < r){
                  arr[r] = arr[l];
                  r--;
              }
    
          }
          //呜呜呜一定要记得填回去啊 
          arr[l] = base;
          return l;
      }
    }

    因为我们是要找下标为k的元素,第一次切分的时候需要遍历整个数组(0 ~ n)找到了下标是j的元素,假如k比j小的话,那么我们下次切分只要遍历数组(0~k-1)的元素就行啦,反之如果k比j大的话,那下次切分只要遍历数组(k+1~n)的元素就行啦,总之平均情况下,可以看作每次调用partition遍历的元素数目都是上一次遍历的1/2,因此时间复杂度是N + N/2 + N/4 + ... + N/N = 2N, 因此时间复杂度是O(N)。


  • 方法二 大根堆:
    时间复杂度O(NlongN)
    用堆时间复杂度会比快排要慢很多,但是Java提供了现成的PriorityQueue(默认小根堆),索引代码实现起来很简单。
    本题是求前K小,因此用一个容量为K的大根堆(每次poll出最大的数,那堆中保留的就是前K个小的数)

    class Solution {
      public int[] getLeastNumbers(int[] arr, int k) {
          if(k == 0 || arr.length == 0){
              return new int[0];
          }
    
          //创建一个堆 重写一下比较器
          Queue<Integer> pq = new PriorityQueue<>((v1, v2)->v2 - v1);
    
          //遍历数组
          for(int num : arr){
              //构造一个k大的大根堆  O(N)
              if(pq.size() < k){
                  pq.offer(num); 
              }else{ 
                  if(num < pq.peek()){
                      pq.poll(); //超过K,调整堆 O(NlongN)
                      pq.offer(num);
                  }
              }
          }
          int[] res = new int[pq.size()];
          int index = 0;
          for(int num : pq){
              res[index++] = num; 
          }
    
          return res;
      }
    }

方法三 BST
时间复杂度 O(NlogN)
因为有重复的数字,所以用的是TreeMap而不是TreeSet(有的语言的标准库自带TreeMultiset,也是可以的)。TreeMap的key是数字,value是该数字的个数。我们遍历数组中的数字,维护一个数字总个数为K的TreeMap,每遍历一个元素:

  1. 若目前map中数字个数小于K,则将map中当前数字对应的个数+1;

  2. 否则,判断当前数字与map中最大的数字的大小关系:若当前数字大于等于map中的最大数字,就直接跳过该数字;若当前数字小于map中的最大数字,则将map中当前数字对应的个数+1,并将map中最大数字对应的个数减1.

class Solution {
    public int[] getLeastNumbers(int[] arr, int k) {
        if (k == 0 || arr.length == 0) {
            return new int[0];
        }
        // TreeMap的key是数字, value是该数字的个数。
        // cnt表示当前map总共存了多少个数字。
        TreeMap<Integer, Integer> map = new TreeMap<>();
        int cnt = 0;
        for (int num: arr) {
            // 1. 遍历数组,若当前map中的数字个数小于k,则map中当前数字对应个数+1
            if (cnt < k) {
                map.put(num, map.getOrDefault(num, 0) + 1);
                cnt++;
                continue;
            } 
            // 2. 否则,取出map中最大的Key(即最大的数字), 判断当前数字与map中最大数字的大小关系:
            //    若当前数字比map中最大的数字还大(或等于),就直接忽略;
            //    若当前数字比map中最大的数字小,则将当前数字加入map中,并将map中的最大数字的个数-1。
            Map.Entry<Integer, Integer> entry = map.lastEntry();
            if (entry.getKey() > num) {
                map.put(num, map.getOrDefault(num, 0) + 1);
                if (entry.getValue() == 1) {
                    map.pollLastEntry();
                } else {
                    map.put(entry.getKey(), entry.getValue() - 1);
                }
            }

        }

        // 最后返回map中的元素
        int[] res = new int[k];
        int idx = 0;
        for (Map.Entry<Integer, Integer> entry: map.entrySet()) {
            int freq = entry.getValue();
            while (freq-- > 0) {
                res[idx++] = entry.getKey();
            }
        }
        return res;
    }
}

方法四 计数排序
时间复杂度 O(N)
仅适用于数据范围有限的topK题

class Solution {
    public int[] getLeastNumbers(int[] arr, int k) {
        if (k == 0 || arr.length == 0) {
            return new int[0];
        }
        // 统计每个数字出现的次数
        int[] counter = new int[10001];
        for (int num: arr) {
            counter[num]++;
        }
        // 根据counter数组从头找出k个数作为返回结果
        int[] res = new int[k];
        int idx = 0;
        for (int num = 0; num < counter.length; num++) {
            while (counter[num]-- > 0 && idx < k) {
                res[idx++] = num;
            }
            if (idx == k) {
                break;
            }
        }
        return res;
    }
}
全部评论

相关推荐

10-10 17:54
点赞 评论 收藏
分享
10-15 03:05
门头沟学院 Java
CADILLAC_:凯文:我的邮箱是死了吗?
点赞 评论 收藏
分享
点赞 1 评论
分享
牛客网
牛客企业服务