2015년 12월 23일 수요일

E. Minimum spanning tree for each edge

Tree, Graph, LCA(lca)


SOL)
mst의 값을 구하고 각 간선의 주어진 u, v의 경로중
지나는 간선의 가장 가중치가 큰 녀석을 빼고 각 간선의
가중치를 더해준게 답이다.
(mst에서 임의의 두 정점 u, v의 경로는 하나이다.)
가중치가 가장 큰 녀석을 구하는 과정에서 lca가 필요하다.
순수하게 간선을 하나씩 탐색하면 시간초과가 나게 된다.
즉 가장 큰 녀석을 빨리 구하는 방법이 필요한데 여기서
lca개념이 들어간다.
우리는 lca를 구하는 원리와 거의 같게 
heavy한 경로를 얻을 수 있다.
(주석도 참고)


cpp to html [-] Collapse
#include<stdio.h>
#include<vector>
#include<algorithm>
using namespace std;
#define mp make_pair
typedef long long ll;
const int N = 200010;
const int L = 20;

struct Edge{
    int u, v, w, idx;
    bool operator < (const Edge& comp){
        return w < comp.w;
    }
};


ll mst, ans[N];
int n, m;
int timer;
int tin[N], tout[N], par[N], up[N][L], heavy[N][L];
vector<pair<int, int> > adj[N];
Edge edge[N];

void dfs(int here, int prev, int cost){
    tin[here] = timer++; //언제 방문했는지 기록
    up[here][0] = prev;
    heavy[here][0] = cost;
    for (int i = 1; i < L; i++) {
        //here의 2^i번째 조상을 저장
        up[here][i] = up[up[here][i - 1]][i - 1];
        //here의 2^i 번째 까지 가는 거쳐가는 간선 중 가장 큰 값을 저장
        //here의 2^i 번째 조상까지의 경로라면 2^(i-1)를 포함한다.
        heavy[here][i] = max(heavy[here][i-1], heavy[up[here][i - 1]][i - 1]);

    }

    for (int i = 0; i < adj[here].size(); i++){
        int next = adj[here][i].first;
        int w = adj[here][i].second;
        if (next != prev){
            dfs(next, here, w);
        }
    }
    tout[here] = timer++; //언제 빠져나왔는지 기록
}

bool upper(int u, int v){
    return tin[u] <= tin[v] && tout[u] >= tout[v];
}

int get_lca(int u, int v){
    if (upper(u, v)) return u; //u가 lca라면
    if (upper(v, u)) return v; //v가 lca라면
    //설명하기 힘든데 모든 경우를 커버할 수 있음..
    for (int i = L - 1; i >= 0; i--){
        if (!upper(up[u][i], v)){
            u = up[u][i]; // up[u][2] == up[up[u][1]][1]
        }
    }
    return up[u][0];
}

int get_heavy(int u, int v){
    int res = 0;
    //u자체가 lca와 같다면 0을반환해야한다.
    //만약 이 if문이 없다면 u의 부모의 간선 가중치를 반환하게 된다.
    //lca를 구할 때 for문에 들어가기 전에 if문을 2개 쓰는데 그게  없는대신
    //이 if문이 필요하다.
    if (u != v){
        for (int i = L - 1; i >= 0; i--){
            if (!upper(up[u][i], v)){
                res = max(res, heavy[u][i]);
                u = up[u][i];
            }
        }
        res = max(res, heavy[u][0]);
    }
    return res;
}
int cal(int u, int v){
    int lca = get_lca(u, v);
    int res = max(get_heavy(u, lca), get_heavy(v, lca));
    return res;
}


int find(int v){
    if (v == par[v]) return v;
    else return par[v] = find(par[v]);
}

void merge(int u, int v){
    u = find(u), v = find(v);
    par[v] = u;
}

void make_kruskal(){
    sort(edge, edge + m);

    for (int i = 0; i < m; i++){
        int u = edge[i].u;
        int v = edge[i].v;
        int w = edge[i].w;
        if (find(u) != find(v)){
            merge(u, v);
            mst += w;
            adj[u].push_back(mp(v, w));
            adj[v].push_back(mp(u, w));
   nbsp;     }
    }
}

void init(){
    scanf("%d %d", &n, &m);

    for (int i = 0; i < N; i++) par[i] = i;
    for (int i = 0; i < m; i++){
        int u, v, w; scanf("%d %d %d", &u, &v, &w);
        edge[i].u = u;
        edge[i].v = v;
        edge[i].w = w;
        edge[i].idx = i;
    }
}

int main(){
    freopen("input.txt", "r", stdin);
    init();
    make_kruskal();
    dfs(1, 1, 0);
    for (int i = 0; i < m; i++){
        int u = edge[i].u;
        int v = edge[i].v;
        int w = edge[i].w;
        int idx = edge[i].idx;

        int res = cal(u, v);
        ans[idx] = mst - res + w;
    }
    for (int i = 0; i < m; i++) printf("%I64d\n", ans[i]);
    return 0;
}

댓글 없음:

댓글 쓰기