并查集+LCA

如果两个rest stops距离小于等于$K$,那我们显然可以把这两个点合并到一块,因为这两个点可以相互到达,基于这样的思想,我们先做一个bfs,把相互能到达的点合并到一个集合中,但是一个点bfs的距离最远为$k$,两个点如果相距为$2k$的话,按照我们bfs的方法,也能走到一起,所以我们把一条边一分为2,每个点bfs的距离为$k$,这样就可以了。这样任意两点距离小于等于$2k$才能相互到达。

询问的时候如果两点间的距离小于等于$2k$,那么就直接YES
否则让两点分别往对方的方向走$k$个距离,利用并查集判断两点是否在一个集合中,在一个集合就是YES,否则就是NO

#include<bits/stdc++.h>

#define ls (rt << 1)
#define rs (rt << 1 | 1)
#define lowbit( x ) (x&(-x))
#define SZ( v ) ((int)(v).size())
#define All( v ) (v).begin(), (v).end()
#define mp( x , y ) make_pair(x,y)
#define fast ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair < int , int > P;
const int N = 4e5 + 10;
const int M = 1e7 + 10;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const int INF = 2e9;
const int seed = 131;
vector<int> G[N];
int n , k ,r ,dis[N] , vis[N] ,dep[N] , fat[N][22] ,fa[N];
int get(int x){
    if(x==fa[x]) return x;
    return fa[x] = get(fa[x]);
}
void merge(int u , int v){
    u = get(u) , v = get(v);
    if(u != v ) fa[v] = u;
}
void dfs(int u , int pa){
    for(auto &v:G[u])
    {
        if(v==pa) continue;
        dep[v] = dep[u] + 1;
        fat[v][0] = u;
        for(int j = 1;j <= 20;++j){
            fat[v][j] = fat[ fat[v][j-1] ][j-1];
        }
        dfs(v,u);
    }
}
int cal(int u,int d)//u往上d距离的节点
{
    int d1 = d , u1 = u;
    for(int i = 0;i<=20;++i){
        if((d>>i)&1) {
            u = fat[u][i];
        }
    }
    //cout << u1 << " 's往上走" << d1 << "距离的节点为" << u << endl;
    return u;
}
int lca(int u , int v)
{
    if(dep[u] < dep[v]) swap(u,v);
    u = cal(u,dep[u]-dep[v]);
    if(u==v) return u;
    for(int i = 20 ; i >= 0 ; -- i){
        if(fat[u][i]!=fat[v][i]){
            u = fat[u][i];
            v = fat[v][i];
        }
    }
    return fat[u][0];
}
// u -> v 走了d
int walk(int u , int v , int Lca ,int d)
{
    if(d <= dep[u] - dep[Lca]) return cal(u,d);
    d -= dep[u] - dep[Lca] ;
    return cal(v,dep[v] - dep[Lca] - d);
}
bool solve(int u , int v)
{
    int Lca = lca(u,v);
    //cout << u << ' ' << v << ' ' << Lca << endl;
    if(dep[u] + dep[v] - 2*dep[Lca] <= 2*k ) return 1;
    int u1 = walk(u,v,Lca,k) , v1 = walk(v,u,Lca,k);
    return (get(u1) == get(v1)) ;
}
int main () {
    fast;
    cin >> n >> k >> r;
    int u , v;
    for(int i  = 1;i < n; ++ i){
        cin >> u >> v;
        G[u].emplace_back(n+i);
        G[n+i].emplace_back(u);
        G[n+i].emplace_back(v);
        G[v].emplace_back(n+i);
    }
    for(int i = 1;i <= n*2;++i) fa[i] = i;
    queue<int> Q;
    for(int i = 0;i < r ; ++ i) {
        cin >> u;
        vis[u] = 1;
        Q.push(u);
    }
    while(!Q.empty())
    {
        u = Q.front();
        Q.pop();
        //cout << "Q:" << u << endl;
        if(dis[u] >= k) break;
        for(auto &v:G[u])
        {
            //cout <<u << "->" << v << endl;
            merge(u,v);
            if(!vis[v]){
                Q.push(v);
                dis[v] = dis[u] + 1;
                vis[v] = 1;
            }
        }
    }
    dfs(1,0);
    int q;
    cin >> q;
    while(q--)
    {
        cin >> u >> v;
        cout << (solve(u,v) ? "YES" : "NO" )<< '\n';
    }
    return 0;
}

Last modification:February 21st, 2020 at 05:02 pm
如果觉得我的文章对你有用,请随意赞赏