[LeetCode] 4. Median of Two Sorted Arrays

Posted by Minho Ryu on October 10, 2021 · 4 mins read

Problem:

Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays. The overall run time complexity should be $\mathrm{O(log(m+n))}$.

Example 1:
Input: nums1 = [1,3], nums2 = [2]
Output: 2.00000
Explanation: merged array = [1,2,3] and median is 2.

Example 2:
Input: nums1 = [1,2], nums2 = [3,4]
Output: 2.50000
Explanation: merged array = [1,2,3,4] and median is (2 + 3) / 2 = 2.5.

Example 3:
Input: nums1 = [0,0], nums2 = [0,0]
Output: 0.00000

Example 4:
Input: nums1 = [], nums2 = [1]
Output: 1.00000

Example 5:
Input: nums1 = [2], nums2 = []
Output: 2.00000


Solution:

Median의 정의를 살펴보면, "set을 두 개의 같은 subset들로 나누고 하나의 subset은 나머지 subset보다 항상 크다."이다. 따라서 두 개의 subset으로 나뉘었을 때 왼쪽 subset의 원소 갯수와 오른쪽 subset의 원소 갯수가 같거나 (짝수) 1개 적도록 만들고 (홀수) 왼쪽 subset의 최댓 값 (맨 오른쪽 원소)이 오른쪽 subset의 최솟 값 (맨 왼쪽 원소)보다 작으면 된다. 다시 문제를 보면 두 개의 분류된 Array가 주어졌을 때, 먼저 왼쪽에 오는 Array가 갯수가 작거나 같도록 길이를 비교해서 분류한다. 이 조건을 유지하면서 왼쪽 Array에 대해 binary search를 적용하면 $\mathrm{O(log(m+n))}$의 time complexity로 해결할 수 있다.


Code:

class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        a, b = sorted((nums1, nums2), key=len)
        m, n = len(a), len(b)
        if n == 0:
            raise ValueError

        imin, imax, half_len = 0, m, (m + n + 1) // 2
        while imin <= imax:
            i = (imin + imax) // 2
            j = half_len - i
            if i < m and b[j-1] > a[i]:
                # i is too small, must increase it
                imin = i + 1
            elif i > 0 and a[i-1] > b[j]:
                # i is too big, must decrease it
                imax = i - 1
            else:
                # i is perfect
                if i == 0: max_of_left = b[j-1]
                elif j == 0: max_of_left = nums1[i-1]
                else: max_of_left = max(a[i-1], b[j-1])

                if (m + n) % 2 == 1:
                    return max_of_left

                if i == m: min_of_right = b[j]
                elif j == n: min_of_right = a[i]
                else: min_of_right = min(a[i], b[j])

                return (max_of_left + min_of_right) / 2.0