题解 | #有序数组中位数#

有序数组中位数

https://www.nowcoder.com/practice/ca181c1bcfec4049a743b8d0dd09912e

这段代码是用来计算三个有序数组的中位数的。主要思路是通过找到第 k 个元素的方法来找到中位数。

在主函数 findMedianSortedArrays 中,首先计算出总元素个数 total。如果 total 是奇数,则通过调用 findKthElement 找到第 (total+1)/2 个元素作为中位数;如果 total 是偶数,则找到第 total/2 和 total/2+1 个元素,并计算其平均值作为中位数返回。

在辅助函数 findKthElement 中,使用三个指针 i、j、l 分别表示三个数组 nums1、nums2、nums3 的当前位置。循环查找第 k 个元素,直到找到。

在每次循环中,通过比较三个数组中当前位置的元素的大小,找到其中最小的元素。然后根据最小元素所在数组的情况,更新指针和 k 的值。具体更新方式是将最小元素所在数组的指针向后移动,并将 k 减去移动的步数。

因此,时间复杂度主要取决于 findKthElement 函数的执行次数。由于每次循环 k 的值都至少减半,所以时间复杂度应该是 O(log(min(m, n, p))),其中 m、n、p 分别表示三个数组的长度。

class Solution {
  public:
    /**
     * 计算三个有序数组的中位数
     *
     * @param nums1 第一个有序数组
     * @param nums2 第二个有序数组
     * @param nums3 第三个有序数组
     * @return 中位数
     */
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2,
                                  vector<int>& nums3) {
        int total = nums1.size() + nums2.size() + nums3.size();
        // 如果总元素个数为奇数,则找到第 (total+1)/2 个元素;如果总元素个数为偶数,则找到第 total/2 和 total/2+1 个元素
        if (total % 2 == 1) {
            return findKthElement(nums1, nums2, nums3, total / 2 + 1);
        } else {
            return (findKthElement(nums1, nums2, nums3, total / 2) + findKthElement(nums1,
                    nums2, nums3, total / 2 + 1)) / 2.0;
        }
    }

  private:
    /**
     * 在三个有序数组中找到第 k 个元素
     *
     * @param nums1 第一个有序数组
     * @param nums2 第二个有序数组
     * @param nums3 第三个有序数组
     * @param k 第 k 个元素
     * @return 第 k 个元素
     */
    double findKthElement(vector<int>& nums1, vector<int>& nums2,
                          vector<int>& nums3,
                          int k) {

        int m = nums1.size(), n = nums2.size(), p = nums3.size();
        int i = 0, j = 0, l =
                           0; // 分别表示在 nums1、nums2、nums3 中的当前位置
        while (true) {
            
            if (k == 1) { // 找到第 1 个元素,即最小的元素
                int mi = 0x3f3f3f3f;
                if (m && i < m) {
                    mi = min(mi, nums1[i]);
                }
                if (n && j < n) {
                    mi = min(mi, nums2[j]);
                }
                if (p && l < p) {
                    mi = min(mi, nums3[l]);
                }
                return mi;
            }

            // 每次比较三个数组中当前位置的元素,并排除掉其中最小的元素
            int half = k / 2;
            int ni = min(i + half, m) - 1;
            int nj = min(j + half, n) - 1;
            int nl = min(l + half, p) - 1;
            int pivot1 = 0x3f3f3f3f, pivot2 = 0x3f3f3f3f, pivot3 = 0x3f3f3f3f;

            if (m) {
                pivot1 = nums1[ni];
            }
            if (n) {
                pivot2 = nums2[nj];
            }
            if (p) {
                pivot3 = nums3[nl];
            }

            if (pivot1 <= pivot2 && pivot1 <= pivot3) {  //pivot1最小
                k -= (ni - i + 1);
                i = ni + 1;
            } else if (pivot2 <= pivot1 && pivot2 <= pivot3) { //pivot2最小
                k -= (nj - j + 1);
                j = nj + 1;
            } else { //pivot3最小
                k -= (nl - l + 1);
                l = nl + 1;
            }

        }
    }
};

全部评论

相关推荐

1 收藏 评论
分享
牛客网
牛客企业服务