中位数堆

最大堆负责左半,最小堆负责右半,始终保持左半与右半等大,或左半比右半恰好多1个元素。见 CtCI p595

// Running-median tracker over a stream of ints, built on two heaps.
//
// `lo` is a max-heap holding the smaller half of the values and `hi` is a
// min-heap holding the larger half. addNum() keeps lo.size() equal to
// hi.size() or exactly one larger, so the median is always read from the
// heap tops.
class MedianFinder {
    std::priority_queue<int> lo; // max-heap: smaller (left) half
    std::priority_queue<int, std::vector<int>, std::greater<int>> hi; // min-heap: larger (right) half
public:
    // Inserts num while keeping lo the same size as hi, or one larger.
    //
    // The new value is always pushed through the *other* heap first and that
    // heap's top is then moved across; this guarantees that every element of
    // lo stays <= every element of hi.
    void addNum(int num) {
        if (lo.size() == hi.size()) {
            // Equal sizes: lo must grow. Route num through hi and move hi's
            // minimum into lo.
            hi.push(num);
            lo.push(hi.top());
            hi.pop();
        } else {
            // lo is one larger: hi must grow. Route num through lo and move
            // lo's maximum into hi.
            lo.push(num);
            hi.push(lo.top());
            lo.pop();
        }
    }

    // Returns the median of all values added so far, or 0 when nothing has
    // been added yet.
    double findMedian() const {
        if (lo.empty()) return 0;
        if (lo.size() == hi.size()) {
            // Sum in double: adding the two ints directly can overflow
            // (e.g. INT_MAX + INT_MAX).
            return (static_cast<double>(lo.top()) + hi.top()) * 0.5;
        }
        return lo.top();
    }
};

中位数队列

// Median tracker over a bag of ints that also supports removing a value.
//
// `lo` holds the smaller half (iterated largest-first) and `hi` holds the
// larger half (iterated smallest-first). std::priority_queue cannot erase an
// arbitrary element, so ordered multisets stand in for the two heaps.
// Invariant after every public call: lo.size() == hi.size() or
// lo.size() == hi.size() + 1.
class MedianQueue {
    std::multiset<int, std::greater<int>> lo; // smaller half; *begin() is its maximum
    std::multiset<int> hi;                    // larger half; *begin() is its minimum

    // Moves the smallest element of hi into lo. Requires hi to be non-empty.
    void moveOneToLo() {
        lo.insert(*hi.begin());
        hi.erase(hi.begin());
    }

    // Moves the largest element of lo into hi. Requires lo to be non-empty.
    void moveOneToHi() {
        hi.insert(*lo.begin());
        lo.erase(lo.begin());
    }
public:
    // Inserts num, keeping lo the same size as hi or one larger. The value
    // is routed through the opposite half so the halves stay ordered.
    void push(int num) {
        if (lo.size() > hi.size()) {
            lo.insert(num);
            moveOneToHi();
        } else {
            hi.insert(num);
            moveOneToLo();
        }
    }

    // Removes one occurrence of num and rebalances the halves.
    // Does nothing if the container is empty or num is not present.
    void remove(int num) {
        // By the size invariant, an empty lo means the whole container is empty.
        if (lo.empty()) return;

        if (num <= *lo.begin()) {
            // Anything <= max(lo) that is stored at all is stored in lo.
            const auto it = lo.find(num);
            if (it == lo.end()) return; // erasing end() would be undefined
            lo.erase(it);
            if (lo.size() < hi.size()) {
                moveOneToLo();
            }
        } else {
            const auto it = hi.find(num);
            if (it == hi.end()) return; // erasing end() would be undefined
            hi.erase(it);
            // Written as an addition so no unsigned subtraction can wrap.
            if (lo.size() > hi.size() + 1) {
                moveOneToHi();
            }
        }
    }

    // Returns the median of the stored values, or 0 when the container is
    // empty.
    double median() const {
        if (lo.empty()) return 0;
        if (lo.size() > hi.size()) {
            return *lo.begin();
        }
        // Sum in double so two large ints cannot overflow.
        return (static_cast<double>(*lo.begin()) + *hi.begin()) / 2;
    }
};

两不等长有序数组的中位数

// Returns the median of the merged contents of two ascending-sorted arrays
// (which may differ in length) by binary-searching a partition point in the
// shorter array. Returns 0 when both arrays are empty.
//
// Idea: split a into a[0..i-1] | a[i..M-1] and b into b[0..j-1] | b[j..N-1].
// With k = (M+N+1)/2, a valid split has i + j == k together with
// a[i-1] <= b[j] and b[j-1] <= a[i].
//  1. The search for i runs over the shorter array. From 0 <= i <= M
//     (i == M means a's right part is empty) and j = k - i we get
//     k-M <= j <= k; keeping 0 <= j <= N needs k-M >= 0 and k <= N, i.e.
//     M <= N.
//  2. When M+N is odd the median is lMax = max(a[i-1], b[j-1]); when it is
//     even the median is (lMax + rMin) / 2 with rMin = min(a[i], b[j]).
double findMedianSortedArrays(vector<int>& a, vector<int>& b) {
	const int M = a.size(), N = b.size();
	if (M > N) return findMedianSortedArrays(b, a);
	// M <= N here, so N == 0 means both inputs are empty. Without this guard
	// the result would be computed from the two sentinels alone.
	if (N == 0) return 0;

	const int k = (M + N + 1) / 2; // number of elements in the left half
	int lo = 0, hi = M;
	while (lo <= hi) {
		const int i = lo + (hi - lo) / 2; // candidate partition point in a
		const int j = k - i;              // matching partition point in b
		// Sentinels stand in for a missing neighbour so that an empty side
		// never fails a comparison below.
		const int la = (i > 0) ? a[i - 1] : INT_MIN;
		const int lb = (j > 0) ? b[j - 1] : INT_MIN;
		const int ra = (i < M) ? a[i] : INT_MAX;
		const int rb = (j < N) ? b[j] : INT_MAX;

		if (la > rb) {
			// i is too far right; discard partition points [i..].
			hi = i - 1;
		} else if (lb > ra) {
			// i is too far left; discard partition points [..i].
			lo = i + 1;
		} else {
			// Valid partition.
			const int lMax = max(la, lb);
			if ((M + N) % 2 == 1) return lMax;
			const int rMin = min(ra, rb);
			// Sum in double: lMax + rMin can overflow int for large values.
			return (static_cast<double>(lMax) + rMin) * 0.5;
		}
	}
	// Not expected for sorted input; kept as a fallback result.
	return -1;
}

两等长有序数组的上中位数

// Returns the upper median of two ascending-sorted arrays of equal length N,
// i.e. the N-th smallest element of the merged 2N elements. Returns 0 when
// the arrays are empty.
//
// Precondition (not checked): arr2 has at least as many elements as arr1;
// the algorithm is written for arr1.size() == arr2.size().
int findMedianinTwoSortedAray(vector<int>& arr1, vector<int>& arr2) {
	const int N = arr1.size(); // the upper median is the N-th smallest merged element
	// Without this guard the final line would index into empty arrays.
	if (N == 0) return 0;

	int lo = 0, hi = N - 1;
	while (lo < hi) {
		const int mid1 = lo + (hi - lo) / 2;
		// Take N elements in total, arr1[0..mid1] and arr2[0..mid2], so that
		// (mid1 + 1) + (mid2 + 1) == N. mid1 < hi <= N-1 keeps mid2 >= 0.
		const int mid2 = N - 2 - mid1;
		const int a1 = arr1[mid1], a2 = arr2[mid2];
		if (a1 < a2) {
			// Of the N chosen elements arr1[mid1] is not the largest, so it
			// cannot be the N-th smallest; discard arr1[0..mid1].
			lo = mid1 + 1;
		} else if (a1 > a2) {
			// arr1[mid1+1..] lies outside the N chosen elements and can be
			// discarded.
			hi = mid1;
		} else {
			return a1;
		}
	}
	// With lo elements taken from arr1, the remaining x satisfy
	// lo + x == N - 1, so x = N - 1 - lo; the answer is the smaller of the
	// next candidate in each array.
	return min(arr1[lo], arr2[N - 1 - lo]);
}