最小生成树MST

  • https://leetcode.com/problems/min-cost-to-connect-all-points/

kruskal加边法

并查集+优先队列。各个顶点先自成集合,按边权重升序遍历边。若边的两个顶点属于不同集合,则选择该边,并将两集合合并。

int minCostConnectPoints(vector<vector<int>>& points) {
    // 最小生成树的Kruskal算法,并查集+优先队列
    // 完全连接E=O(N^2),所以O(ElgE)=O(N^2*2N)=O(N^3)
    const int N = points.size();
    vector<int> uf(N);
    iota(begin(uf), end(uf), 0);

    using arr3 = array<int, 3>;
    auto cmp = [](arr3 &a, arr3 &b) {
        return a[0] > b[0]; // 最小堆
    };        
    priority_queue<arr3, vector<arr3>, decltype(cmp)> pq(cmp);
    for (int i = 0; i < N; i++) {
        for (int j = i + 1; j < N; j++) {
            int cost = distance(i, j, points);
            pq.push({cost, i, j});
        }
    }

    int ans = 0, cnt = 0;
    while (!pq.empty() && cnt < N - 1) { // MST有N-1条边
        const auto [cost, x, y] = pq.top(); pq.pop();
        // unite
        int px = find(x, uf), py = find(y, uf);
        if (px != py) {
            uf[py] = px;
            ans += cost;
            cnt++;
        }
    }
    return ans;
}

int find(int x, vector<int> &uf) {
    if (uf[x] != x) uf[x] = find(uf[x], uf);
    return uf[x];
}

int distance(int u, int v, vector<vector<int>> &points) {
    auto &pu = points[u], &pv = points[v];
    return abs(pu[0] - pv[0]) + abs(pu[1] - pv[1]);
}

prim加点法

当把顶点u加入MST时,松弛所有未访问顶点v,d[v]=min(d[v],cost(u,v))d[v]表示顶点v到最小生成树MST的距离。 而dijkstra算法用d[v]=min(d[v],dist[u]+cost(u,v))来松弛,d[v]表示顶点v到源点的距离,区别只有这一点。

int minCostConnectPoints(vector<vector<int>>& points) {
    // 最小生成树的Prim算法,图遍历+优先队列
    // O(N^2)
    const int N = points.size();
    vector<int> dist(N, INT_MAX); // 各节点到最小生成树的距离
    dist[0] = 0;

    using arr2 = array<int, 2>; // [dist, idx]
    auto cmp = [](arr2 &a, arr2 &b) {
        return a[0] > b[0]; // 最小堆
    };
    priority_queue<arr2, vector<arr2>, decltype(cmp)> pq(cmp);
    pq.push({dist[0], 0});

    int ans = 0, cnt = 0;
    vector<int> visited(N, false);
    while (!pq.empty() && cnt < N) { // MST有N个顶点
        auto [d, u] = pq.top();  pq.pop();
        if (visited[u]) continue;
        visited[u] = true;
        ans += d;
        cnt++;

        for (int v = 0; v < N; v++) {
            if (visited[v]) continue;
            int newdist = distance(u, v, points);
            if (newdist < dist[v]) {
                dist[v] = newdist;
                pq.push({newdist, v});
            }
        }
    }
    return ans;
}

int distance(int u, int v, vector<vector<int>> &points) {
    auto &pu = points[u], &pv = points[v];
    return abs(pu[0] - pv[0]) + abs(pu[1] - pv[1]);
}