题解 | #两个有序数组间相加和的Topk问题#
// 牛客的测试链接: // https://www.nowcoder.com/practice/7201cacf73e7495aa5f88b223bbbf6d1 // 不要提交包信息,把import底下的类名改成Main,提交下面的代码可以直接通过 // 因为测试平台会卡空间,所以把set换成了动态加和减的结构
import java.util.Scanner; import java.util.Comparator; import java.util.HashSet; import java.util.PriorityQueue;
public class Main {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int N = scanner.nextInt();
int K = scanner.nextInt();
int[] arr1 = new int[N];
int[] arr2 = new int[N];
for (int i = 0; i < N; i++) {
arr1[i] = scanner.nextInt();
}
for (int i = 0; i < N; i++) {
arr2[i] = scanner.nextInt();
}
int[] topK = topKSum(arr1,arr2,K);
for (int i = 0; i < K; i++) {
System.out.print(topK[i] + " ");
}
System.out.println();
scanner.close();
}
public static class Node{
public int index1;
public int index2;
public int sum;
public Node(int index1,int index2,int sum){
this.index1 = index1;
this.index2 = index2;
this.sum = sum; //arr1[index1]+arr2[index2]
}
}
//o2.sum - o1.sum : 降序 排序比较器
public static class NodeComparator implements Comparator<Node>{
@Override
public int compare(Node o1, Node o2) {
return o2.sum - o1.sum;
}
}
public static int[] topKSum(int[] arr1, int[] arr2, int topK) {
int N = arr1.length;
int M = arr2.length;
topK = Math.min(topK,N * M);
// res的长度可能是topK,也可能是N*M,向res填sum时,
// 需要将进行下标换算( i1 * M + i2),确保每个sum都是 arr1[i1] + arr2[i2]得到的唯一sum(sum可能相同,但组成sum的来源不同)
int[] res = new int[topK];
int resIndex = 0;
// 大根堆,最大元素放在堆顶
PriorityQueue<Node> maxHeap = new PriorityQueue<Node>(new NodeComparator());
HashSet<Long> set = new HashSet<>();
int i1 = N -1;
int i2 = M -1;
set.add(calcuIndex(i1,i2,M));
maxHeap.add(new Node(i1,i2,arr1[i1] + arr2[i2]));
while (resIndex != topK){
Node curNode = maxHeap.poll();
res[resIndex++] = curNode.sum;
i1 = curNode.index1;
i2 = curNode.index2;
set.remove(calcuIndex(i1,i2,M));
// 切记,一定要保证set集合和maxHeap中的数据要同步,maxHeap poll过元素,set中也一定要移除对应元素,避免脏数据
// 每个sum 都对应了一个resIndex,!set.contains(calcuIndex(i1-1,i2,M)可以判断当前得到的sum是否在之前遍历过
if (i1 - 1 >= 0 && !set.contains(calcuIndex(i1-1,i2,M))){
maxHeap.add(new Node(i1-1,i2,arr1[i1-1] + arr2[i2]));
set.add(calcuIndex(i1-1,i2,M));
}
if (i2 - 1 >= 0 && !set.contains(calcuIndex(i1,i2-1,M))){
maxHeap.add(new Node(i1,i2-1,arr1[i1] + arr2[i2-1]));
set.add(calcuIndex(i1,i2-1,M));
}
}
return res;
}
public static long calcuIndex(int i1,int i2,int M){
return (long) i1 * (long)M + (long)i2;
}
}