题解 | #有序数组中位数#
有序数组中位数
https://www.nowcoder.com/practice/ca181c1bcfec4049a743b8d0dd09912e
这段代码是用来计算三个有序数组的中位数的。主要思路是通过找到第 k 个元素的方法来找到中位数。
在主函数 findMedianSortedArrays 中,首先计算出总元素个数 total。如果 total 是奇数,则通过调用 findKthElement 找到第 (total+1)/2 个元素作为中位数;如果 total 是偶数,则找到第 total/2 和 total/2+1 个元素,并计算其平均值作为中位数返回。
在辅助函数 findKthElement 中,使用三个指针 i、j、l 分别表示三个数组 nums1、nums2、nums3 的当前位置。循环查找第 k 个元素,直到找到。
在每次循环中,通过比较三个数组中当前位置的元素的大小,找到其中最小的元素。然后根据最小元素所在数组的情况,更新指针和 k 的值。具体更新方式是将最小元素所在数组的指针向后移动,并将 k 减去移动的步数。
因此,时间复杂度主要取决于 findKthElement 函数的执行次数。由于每次循环 k 的值都至少减半,所以时间复杂度应该是 O(log(min(m, n, p))),其中 m、n、p 分别表示三个数组的长度。
class Solution {
public:
/**
* 计算三个有序数组的中位数
*
* @param nums1 第一个有序数组
* @param nums2 第二个有序数组
* @param nums3 第三个有序数组
* @return 中位数
*/
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2,
vector<int>& nums3) {
int total = nums1.size() + nums2.size() + nums3.size();
// 如果总元素个数为奇数,则找到第 (total+1)/2 个元素;如果总元素个数为偶数,则找到第 total/2 和 total/2+1 个元素
if (total % 2 == 1) {
return findKthElement(nums1, nums2, nums3, total / 2 + 1);
} else {
return (findKthElement(nums1, nums2, nums3, total / 2) + findKthElement(nums1,
nums2, nums3, total / 2 + 1)) / 2.0;
}
}
private:
/**
* 在三个有序数组中找到第 k 个元素
*
* @param nums1 第一个有序数组
* @param nums2 第二个有序数组
* @param nums3 第三个有序数组
* @param k 第 k 个元素
* @return 第 k 个元素
*/
double findKthElement(vector<int>& nums1, vector<int>& nums2,
vector<int>& nums3,
int k) {
int m = nums1.size(), n = nums2.size(), p = nums3.size();
int i = 0, j = 0, l =
0; // 分别表示在 nums1、nums2、nums3 中的当前位置
while (true) {
if (k == 1) { // 找到第 1 个元素,即最小的元素
int mi = 0x3f3f3f3f;
if (m && i < m) {
mi = min(mi, nums1[i]);
}
if (n && j < n) {
mi = min(mi, nums2[j]);
}
if (p && l < p) {
mi = min(mi, nums3[l]);
}
return mi;
}
// 每次比较三个数组中当前位置的元素,并排除掉其中最小的元素
int half = k / 2;
int ni = min(i + half, m) - 1;
int nj = min(j + half, n) - 1;
int nl = min(l + half, p) - 1;
int pivot1 = 0x3f3f3f3f, pivot2 = 0x3f3f3f3f, pivot3 = 0x3f3f3f3f;
if (m) {
pivot1 = nums1[ni];
}
if (n) {
pivot2 = nums2[nj];
}
if (p) {
pivot3 = nums3[nl];
}
if (pivot1 <= pivot2 && pivot1 <= pivot3) { //pivot1最小
k -= (ni - i + 1);
i = ni + 1;
} else if (pivot2 <= pivot1 && pivot2 <= pivot3) { //pivot2最小
k -= (nj - j + 1);
j = nj + 1;
} else { //pivot3最小
k -= (nl - l + 1);
l = nl + 1;
}
}
}
};